refactor(all): kag v0.6 (#174)

* add path find

* fix find path

* spg guided relation extraction

* fix dict parse with same key

* rename graphalgoclient to graphclient

* rename graphalgoclient to graphclient

* file reader supports http url

* add checkpointer class

* parser supports checkpoint

* add build

* remove incorrect logs

* remove logs

* update examples

* update chain checkpointer

* vectorizer batch size set to 32

* add a zodb backended checkpointer

* add a zodb backended checkpointer

* fix zodb based checkpointer

* add thread for zodb IO

* fix(common): resolve multithread conflict in zodb IO

* fix(common): load existing zodb checkpoints

* update examples

* update examples

* fix zodb writer

* add docstring

* fix jieba version mismatch

* commit kag_config-tc.yaml

1. rename type to register_name
2. assign a unique & specific name to each register_name
3. rename reader to scanner
4. rename parser to reader
5. rename num_parallel to num_parallel_file, rename chain_level_num_paralle to num_parallel_chain_of_file
6. rename kag_extractor to schema_free_extractor, schema_base_extractor to schema_constraint_extractor
7. pre-define llm & vectorize_model and reference them in the yaml file (a Python-dict sketch of the renamed layout follows the issues list below)

Issues to be resolved:
1. examples of event extraction & spg extraction
2. indexer statistics, such as the number of nodes & edges extracted and the ratio of llm invocations.
3. Exceptions (e.g. Debt, account does not exist) should be raised during llm invocation.
4. The solver conf needs to be re-examined.
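As referenced in item 7 above, a minimal Python-dict sketch of the renamed layout (the kag_config.yaml equivalent uses the same keys). The register names ("dir", "txt_reader", "batch_vectorizer", "maas", "openai") and the "type" key are illustrative assumptions, not the shipped defaults:

    # Pre-defined llm & vectorize_model, referenced below instead of repeated per component.
    llm_conf = {"type": "maas", "model": "gpt-4o"}                # illustrative
    vectorize_model_conf = {"type": "openai", "model": "bge-m3"}  # illustrative

    builder_conf = {
        "scanner": {"type": "dir"},                      # formerly "reader"
        "reader": {"type": "txt_reader"},                # formerly "parser"
        "extractor": {
            "type": "schema_free_extractor",             # formerly "kag_extractor"
            "llm": llm_conf,
        },
        "vectorizer": {
            "type": "batch_vectorizer",
            "vectorize_model": vectorize_model_conf,
        },
        "num_parallel_file": 2,                          # formerly "num_parallel"
        "num_parallel_chain_of_file": 8,                 # formerly "chain_level_num_paralle"
    }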

* commit kag_config-tc.yaml

1. rename type to register_name
2. assign a unique & specific name to each register_name
3. rename reader to scanner
4. rename parser to reader
5. rename num_parallel to num_parallel_file, rename chain_level_num_paralle to num_parallel_chain_of_file
6. rename kag_extractor to schema_free_extractor, schema_base_extractor to schema_constraint_extractor
7. pre-define llm & vectorize_model and reference them in the yaml file

Issues to be resolved:
1. examples of event extraction & spg extraction
2. indexer statistics, such as the number of nodes & edges extracted and the ratio of llm invocations.
3. Exceptions (e.g. Debt, account does not exist) should be raised during llm invocation.
4. The solver conf needs to be re-examined.

* fix bug in base_table_splitter

* fix bug in base_table_splitter

* fix bug in default_chain

* add solver

* add kag

* update outline splitter

* add main test

* add op

* code refactor

* add tools

* fix outline splitter

* fix outline prompt

* graph api pass

* commit with page rank

* add search api and graph api

* add markdown report

* fix vectorizer num batch compute

* add retry for vectorize model call

* update markdown reader

* update markdown reader

* update pdf reader

* raise extractor failure

* add default expr

* add log

* merge jc reader features

* rm import

* add build

* fix zodb based checkpointer

* add thread for zodb IO

* fix(common): resolve multithread conflict in zodb IO

* fix(common): load existing zodb checkpoints

* update examples

* update examples

* fix zodb writer

* add docstring

* fix jieba version mismatch

* commit kag_config-tc.yaml

1. rename type to register_name
2. assign a unique & specific name to each register_name
3. rename reader to scanner
4. rename parser to reader
5. rename num_parallel to num_parallel_file, rename chain_level_num_paralle to num_parallel_chain_of_file
6. rename kag_extractor to schema_free_extractor, schema_base_extractor to schema_constraint_extractor
7. pre-define llm & vectorize_model and reference them in the yaml file

Issues to be resolved:
1. examples of event extraction & spg extraction
2. indexer statistics, such as the number of nodes & edges extracted and the ratio of llm invocations.
3. Exceptions (e.g. Debt, account does not exist) should be raised during llm invocation.
4. The solver conf needs to be re-examined.

* commit kag_config-tc.yaml

1. rename type to register_name
2. assign a unique & specific name to each register_name
3. rename reader to scanner
4. rename parser to reader
5. rename num_parallel to num_parallel_file, rename chain_level_num_paralle to num_parallel_chain_of_file
6. rename kag_extractor to schema_free_extractor, schema_base_extractor to schema_constraint_extractor
7. pre-define llm & vectorize_model and reference them in the yaml file

Issues to be resolved:
1. examples of event extraction & spg extraction
2. indexer statistics, such as the number of nodes & edges extracted and the ratio of llm invocations.
3. Exceptions (e.g. Debt, account does not exist) should be raised during llm invocation.
4. The solver conf needs to be re-examined.

* fix bug in base_table_splitter

* fix bug in base_table_splitter

* fix bug in default_chain

* update outline splitter

* add main test

* add markdown report

* code refactor

* fix outline splitter

* fix outline prompt

* update markdown reader

* fix vectorizer num batch compute

* add retry for vectorize model call

* update markdown reader

* raise extractor failure

* rm parser

* run pipeline

* add a config option for whether to perform the llm config check, defaulting to false

* fix

* recover pdf reader

* several components can be null for default chain

* support full QA runs

* add if

* remove unused code

* use chunks as a fallback

* exclude the source relation from selection candidates

* add generate

* default recall 10

* add local memory

* exclude similar edges

* add safeguards

* fix concurrency issue

* add debug logger

* support parameterized topk

* support chunk truncation and adjust the spo select prompt

* add protection for query requests

* add force_chunk configuration

* fix entity linker algorithm

* add sub query rewriting

* fix md reader dup in test

* fix

* merge knext to kag parallel

* fix package

* fix metric drop issue

* scanner update

* scanner update

* add doc and update example scripts

* fix

* add bridge to spg server

* add format

* fix bridge

* update conf for baike

* disable ckpt for spg server runner

* llm invoke errors raise exceptions by default

* chore(version): bump version to X.Y.Z

* update default response generation prompt

* add method getSummarizationMetrics

* fix(common): fix project conf empty error

* fix typo

* add reporting information

* update main solver

* postprocessor supports spg server

* update solver supported names

* fix language

* update chunker interface and add openapi

* rename vectorizer to vectorize_model in spg server config

* make generate_random_string output start with gen

* add knext llm vector checker

* add knext llm vector checker

* add knext llm vector checker

* remove default values from solver

* update yaml and register_name for baike

* update yaml and register_name for baike

* remove config key check

* fix llm module

* fix knext project

* update yaml and register_name for examples

* update yaml and register_name for examples

* Revert "udpate yaml and register_name for examples"

This reverts commit b3fa5ca9ba.

* update register name

* fix

* fix

* support multiple register names

* update component

* update reader register names (#183)

* fix markdown reader

* fix llm client for retry

* feat(common): add processed chunk id checkpoint (#185)

* update reader register names

* add processed chunk id checkpoint

* feat(example): add example config (#186)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* add max_workers parameter for getSummarizationMetrics to make it faster

* add csqa data generation script generate_data.py

* commit generated csqa builder and solver data

* add csqa basic project files

* adjust split_length and num_threads_per_chain to match lightrag settings

* ignore ckpt dirs

* add csqa evaluation script eval.py

* save evaluation scripts summarization_metrics.py and factual_correctness.py

* save LightRAG output csqa_lightrag_answers.json

* ignore KAG output csqa_kag_answers.json

* add README.md for CSQA

* fix(solver): fix solver pipeline conf (#191)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* update links and file paths

* reformat csqa kag_config.yaml

* reformat csqa python files

* reformat getSummarizationMetrics and compare_summarization_answers

* fix(solver): fix solver config (#192)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* fix main solver conf

* add except

* fix typo in csqa README.md

* feat(conf): support reinitialize config for call from java side (#199)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* fix main solver conf

* support reinitialize config for java call

* revert default response generation prompt

* update project list

* add README.md for the hotpotqa, 2wiki and musique examples

* add spo retrieval

* turn off kag config dump by default

* turn off knext schema dump by default

* add .gitignore and fix kag_config.yaml

* add README.md for the medicine example

* add README.md for the supplychain example

* bugfix for risk mining

* use exact out

* refactor(solver): format solver code (#205)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* fix main solver conf

* support reinitialize config for java call

* black format

---------

Co-authored-by: peilong <peilong.zpl@antgroup.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
Co-authored-by: huaidong.xhd <huaidong.xhd@antgroup.com>
zhuzhongshu123 2025-01-03 17:10:51 +08:00 committed by GitHub
parent 9dabe07282
commit e1d818dfaa
678 changed files with 215065 additions and 14479 deletions

.github/workflows/code-check.yml

@@ -0,0 +1,34 @@
name: CI
on:
push:
pull_request:
workflow_dispatch:
repository_dispatch:
types: [my_event]
jobs:
format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit pytest pytest-cov
pip install -r requirements.txt
pip install -e .
pip install black==24.10.0
- name: Run pre-commit
run: pre-commit run --all-files
# - name: Run unit tests
# run: pushd tests/unit && pytest && popd

.github/workflows/pr-title-check.yml

@@ -0,0 +1,28 @@
name: "Lint PR"
on:
pull_request_target:
types:
- opened
- edited
- synchronize
jobs:
main:
name: Validate PR title
runs-on: ubuntu-latest
steps:
# https://www.conventionalcommits.org/en/v1.0.0/#summary
- uses: amannn/action-semantic-pull-request@v5
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
requireScope: true
subjectPattern: ^(?![A-Z]).+$
# If `subjectPattern` is configured, you can use this property to override
# the default error message that is shown when the pattern doesn't match.
# The variables `subject` and `title` can be used within the message.
subjectPatternError: |
The subject "{subject}" found in the pull request title "{title}"
didn't match the configured pattern. Please ensure that the subject
doesn't start with an uppercase character.

.gitignore

@@ -12,4 +12,6 @@
*.pyc
/dist
.vscode/
.idea/
.venv/
__pycache__/

.pre-commit-config.yaml

@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
files: ^kag/.*\.py$
exclude: |
(?x)^(
kag/solver/logic/core_modules/rule_runner/rule_runner.py |
kag/solver/logic/core_modules/parser/logic_node_parser.py
)$
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
hooks:
- id: flake8
files: ^kag/.*\.py$

@@ -1 +1 @@
0.5.2-beta1
0.6

@@ -1,2 +1,6 @@
recursive-include kag *
recursive-exclude kag/examples *
global-exclude *.pyc
global-exclude *.pyo
global-exclude *.pyd
global-exclude __pycache__

build.sh

@@ -0,0 +1,5 @@
rm -rf build
rm -rf dist
python setup.py sdist bdist_wheel

@@ -1,3 +1,4 @@
# flake8: noqa
# Apache License
# Version 2.0, January 2004
# http://www.apache.org/licenses/
@@ -202,8 +203,27 @@
__package_name__ = "openspg-kag"
__version__ = "0.5.2-beta1"
__version__ = "0.6"
from kag.common.env import init_env
# Register Built-in Components
from kag.common.conf import init_env
init_env()
import kag.interface
import kag.interface.solver.execute
import kag.interface.solver.plan
import kag.solver.execute
import kag.solver.plan
import kag.solver.retriever
import kag.solver.tools
import kag.builder.component
import kag.builder.default_chain
import kag.builder.runner
import kag.builder.prompt
import kag.solver.prompt
import kag.common.vectorize_model
import kag.common.llm
import kag.common.checkpointer
import kag.solver
import kag.bin.commands
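These imports register the built-in components as a side effect of "import kag". A minimal sketch of inspecting the result, assuming list_available() is exposed on the interface classes the same way it is used on Command in kag/bin/base.py below:

    import kag  # importing the package runs the registration imports above
    from kag.interface import ExtractorABC

    # Expect register names such as "schema_free_extractor" and "schema_constraint_extractor".
    print(ExtractorABC.list_available())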

kag/bin/base.py

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import argparse
import logging
import typing
from kag.common.registry import Registrable
logger = logging.getLogger()
def add_commands(
subparsers: argparse._SubParsersAction, command_names: typing.List[str] = None
):
"""add commands to subparsers"""
all_cmds = Command.list_available()
if command_names is None:
logger.warn("invalid command_names, will add all available commands.")
command_names = all_cmds
for cmd in command_names:
if cmd not in all_cmds:
raise ValueError(f"command {cmd} not in available commands {all_cmds}")
# Command Subclasses doesn't accept init args, so just pass subclass name is OK.
cls = Command.from_config(cmd)
cls.add_to_parser(subparsers)
class Command(Registrable):
def get_handler(self):
"""return handler of current command"""
return self.handler
def add_to_parser(self, subparsers: argparse._SubParsersAction):
"""setup accept arguments"""
raise NotImplementedError("setup_parser not implemented yet.")
@staticmethod
def handler(args: argparse.Namespace):
"""function to proces the request."""
raise NotImplementedError("handler not implemented yet.")

@@ -9,10 +9,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.bin.commands.info import ListRegisterInfo
from kag.common.retriever.kag_retriever import DefaultRetriever
from kag.common.retriever.retriever import Retriever
__all__ = [
"DefaultRetriever",
"Retriever"
]
__all__ = ["ListRegisterInfo"]

kag/bin/commands/info.py

@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import argparse
from tabulate import tabulate
from kag.bin.base import Command
from kag.common.registry import Registrable
from kag.common.utils import reset, bold, red, green, blue
@Command.register("register_info")
class ListRegisterInfo(Command):
def add_to_parser(self, subparsers: argparse._SubParsersAction):
parser = subparsers.add_parser(
"interface", help="Show the interface info of the KAG components."
)
parser.add_argument("--cls", help="class name to query")
parser.add_argument(
"--list", help="list all component interfaces in KAG", action="store_true"
)
parser.set_defaults(func=self.get_handler())
@staticmethod
def get_cls(cls_name):
interface_classes = Registrable.list_all_registered(with_leaf_classes=False)
for item in interface_classes:
if item.__name__ == cls_name:
return item
raise ValueError(f"class {cls_name} is not a valid kag configurable class")
@staticmethod
def handle_list(args: argparse.Namespace):
interface_classes = Registrable.list_all_registered(with_leaf_classes=False)
data = []
for cls in interface_classes:
data.append([cls.__name__, cls.__module__])
headers = [f"{bold}{red}class{reset}", f"{bold}{red}module{reset}"]
msg = (
f"{bold}{red}Below are the interfaces provided by KAG."
f"For detailed information on each class, please use the command `kag interface --cls $class_name`{reset}"
)
print(msg)
print(tabulate(data, headers, tablefmt="grid"))
@staticmethod
def handle_cls(args: argparse.Namespace):
cls_obj = ListRegisterInfo.get_cls(args.cls)
if not issubclass(cls_obj, Registrable):
raise ValueError(f"class {args.cls} is not a valid kag configurable class")
availables = cls_obj.list_available_with_detail()
seg = " " * 20
deduped_availables = {}
for register_name, cls_info in availables.items():
cls = cls_info["class"]
if cls not in deduped_availables:
deduped_availables[cls] = [register_name]
else:
deduped_availables[cls].append(register_name)
print(f"{bold}{red}{seg}Documentation of {args.cls}{seg}{reset}")
import inspect
print(inspect.getdoc(cls_obj))
print(f"{bold}{red}{seg}Registered subclasses of {args.cls}{seg}{reset}")
visited = set()
for register_name, cls_info in availables.items():
cls = cls_info["class"]
if cls in visited:
continue
visited.add(cls)
print(f"{bold}{blue}[{cls}]{reset}")
register_names = " / ".join([f'"{x}"' for x in deduped_availables[cls]])
print(f"{bold}{green}Register Name:{reset} {register_names}\n")
# print(f"Class Name: {cls_info['class']}")
print(f"{bold}{green}Documentation:{reset}\n{cls_info['doc']}\n")
print(f"{bold}{green}Initializer:{reset}\n{cls_info['constructor']}\n")
required_arguments = []
for item in cls_info["params"]["required_params"]:
required_arguments.append(f" {item}")
if len(required_arguments) == 0:
required_arguments = " No Required Arguments found"
else:
required_arguments = "\n".join(required_arguments)
print(f"{bold}{green}Required Arguments:{reset}\n{required_arguments}\n")
optional_arguments = []
for item in cls_info["params"]["optional_params"]:
optional_arguments.append(f" {item}")
if len(optional_arguments) == 0:
optional_arguments = " No Optional Arguments found"
else:
optional_arguments = "\n".join(optional_arguments)
print(f"{bold}{green}Optional Arguments:{reset}\n{optional_arguments}\n")
print(f"{bold}{green}Sample Useage:{reset}\n {cls_info['sample_useage']}")
# for k, v in cls_info.items():
# print(f"{k}: {v}")
print("\n")
@staticmethod
def handler(args: argparse.Namespace):
if args.list:
ListRegisterInfo.handle_list(args)
else:
ListRegisterInfo.handle_cls(args)

kag/bin/kag_cmds.py

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import argparse
from kag.bin.base import add_commands
def build_parser():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(
dest="subcommand_name",
title="subcommands",
help="subcommands supported by kag",
)
# add registered commands to parser
cmds = [
"register_info",
]
add_commands(subparsers, cmds)
return parser
def main():
"""entry point of script"""
parser = build_parser()
args = parser.parse_args()
args.func(args)
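A quick way to exercise this entry point without the installed console script, equivalent to running "kag interface --list" from the shell (a sketch that assumes the package is importable):

    from kag.bin.kag_cmds import build_parser

    # Dispatches to ListRegisterInfo.handler, which routes --list to handle_list.
    args = build_parser().parse_args(["interface", "--list"])
    args.func(args)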

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import json
import kag.interface as interface
from kag.common.conf import KAGConstants, init_env
def init_kag_config(project_id: str, host_addr: str):
os.environ[KAGConstants.ENV_KAG_PROJECT_ID] = project_id
os.environ[KAGConstants.ENV_KAG_PROJECT_HOST_ADDR] = host_addr
init_env()
class SPGServerBridge:
def __init__(self):
pass
def run_reader(self, config, input_data):
if isinstance(config, str):
config = json.loads(config)
scanner_config = config["scanner"]
reader_config = config["reader"]
scanner = interface.ScannerABC.from_config(scanner_config)
reader = interface.ReaderABC.from_config(reader_config)
chunks = []
for data in scanner.generate(input_data):
chunks += reader.invoke(data, write_ckpt=False)
return [x.to_dict() for x in chunks]
def run_component(self, component_name, component_config, input_data):
if isinstance(component_config, str):
component_config = json.loads(component_config)
cls = getattr(interface, component_name)
instance = cls.from_config(component_config)
if hasattr(instance.input_types, "from_dict"):
input_data = instance.input_types.from_dict(input_data)
return [x.to_dict() for x in instance.invoke(input_data, write_ckpt=False)]
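A hedged sketch of driving the bridge from Python; the Java side passes the same JSON payloads. The module path is assumed from the "add bridge to spg server" commit, and the reader config and input path are illustrative:

    import json
    from kag.bridge.spg_server_bridge import SPGServerBridge, init_kag_config  # assumed module path

    init_kag_config(project_id="1", host_addr="http://127.0.0.1:8887")
    bridge = SPGServerBridge()

    # run_component resolves "ReaderABC" via getattr on kag.interface; the
    # "txt_reader" register name and the input path are illustrative.
    chunks = bridge.run_component(
        "ReaderABC", json.dumps({"type": "txt_reader"}), "./data/sample.txt"
    )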

@@ -1,10 +0,0 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

@@ -10,13 +10,76 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.external_graph.external_graph import (
DefaultExternalGraphLoader,
)
from kag.builder.component.extractor.schema_free_extractor import SchemaFreeExtractor
from kag.builder.component.extractor.schema_constraint_extractor import (
SchemaConstraintExtractor,
)
from kag.builder.component.aligner.kag_aligner import KAGAligner
from kag.builder.component.aligner.spg_aligner import SPGAligner
from kag.builder.component.postprocessor.kag_postprocessor import KAGPostProcessor
from kag.builder.component.mapping.spg_type_mapping import SPGTypeMapping
from kag.builder.component.mapping.relation_mapping import RelationMapping
from kag.builder.component.mapping.spo_mapping import SPOMapping
from kag.builder.component.scanner.csv_scanner import CSVScanner
from kag.builder.component.scanner.json_scanner import JSONScanner
from kag.builder.component.scanner.yuque_scanner import YuqueScanner
from kag.builder.component.scanner.dataset_scanner import (
MusiqueCorpusScanner,
HotpotqaCorpusScanner,
)
from kag.builder.component.scanner.file_scanner import FileScanner
from kag.builder.component.scanner.directory_scanner import DirectoryScanner
from kag.builder.component.reader.pdf_reader import PDFReader
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.component.reader.docx_reader import DocxReader
from kag.builder.component.reader.txt_reader import TXTReader
from kag.builder.component.reader.mix_reader import MixReader
from kag.builder.component.reader.dict_reader import DictReader
from kag.builder.component.splitter.length_splitter import LengthSplitter
from kag.builder.component.splitter.pattern_splitter import PatternSplitter
from kag.builder.component.splitter.outline_splitter import OutlineSplitter
from kag.builder.component.splitter.semantic_splitter import SemanticSplitter
from kag.builder.component.vectorizer.batch_vectorizer import BatchVectorizer
from kag.builder.component.writer.kg_writer import KGWriter
__all__ = [
"DefaultExternalGraphLoader",
"SchemaFreeExtractor",
"SchemaConstraintExtractor",
"KAGAligner",
"SPGAligner",
"KAGPostProcessor",
"KGWriter",
"SPGTypeMapping",
"RelationMapping",
"SPOMapping",
"TXTReader",
"PDFReader",
"MarkDownReader",
"DocxReader",
"MixReader",
"DictReader",
"JSONScanner",
"HotpotqaCorpusScanner",
"MusiqueCorpusScanner",
"FileScanner",
"DirectoryScanner",
"YuqueScanner",
"CSVScanner",
"LengthSplitter",
"PatternSplitter",
"OutlineSplitter",
"SemanticSplitter",
"BatchVectorizer",
"KGWriter",
]

@@ -1,12 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

@@ -13,12 +13,25 @@
from typing import List, Sequence, Dict, Type
from kag.builder.model.sub_graph import SubGraph
from kag.interface.builder import AlignerABC
from kag.interface import AlignerABC
from knext.common.base.runnable import Input, Output
class KAGPostProcessorAligner(AlignerABC):
@AlignerABC.register("kag")
class KAGAligner(AlignerABC):
"""
A class that extends the AlignerABC base class. It is responsible for aligning and merging subgraphs.
This class provides methods to handle the alignment and merging of subgraphs, as well as properties to define the input and output types.
"""
def __init__(self, **kwargs):
"""
Initializes the KAGAligner instance.
Args:
**kwargs: Arbitrary keyword arguments passed to the parent class constructor.
"""
super().__init__(**kwargs)
@property
@@ -30,6 +43,16 @@ class KAGPostProcessorAligner(AlignerABC):
return SubGraph
def invoke(self, input: List[SubGraph], **kwargs) -> SubGraph:
"""
Merges a list of subgraphs into a single subgraph.
Args:
input (List[SubGraph]): A list of subgraphs to be merged.
**kwargs: Additional keyword arguments.
Returns:
SubGraph: The merged subgraph containing all nodes and edges from the input subgraphs.
"""
merged_sub_graph = SubGraph(nodes=[], edges=[])
for sub_graph in input:
for node in sub_graph.nodes:
@@ -41,9 +64,15 @@ class KAGPostProcessorAligner(AlignerABC):
return merged_sub_graph
def _handle(self, input: Sequence[Dict]) -> Dict:
"""
Handles the input by converting it to the appropriate type, invoking the aligner, and converting the output back to a dictionary.
Args:
input (Sequence[Dict]): A sequence of dictionaries representing subgraphs.
Returns:
Dict: A dictionary representing the merged subgraph.
"""
_input = [self.input_types.from_dict(i) for i in input]
_output = self.invoke(_input)
return _output.to_dict()
def batch(self, inputs: List[Input], **kwargs) -> List[Output]:
pass
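A minimal usage sketch of the renamed aligner through the registry; the "type" config key and the empty subgraphs are illustrative:

    from kag.interface import AlignerABC
    from kag.builder.model.sub_graph import SubGraph

    aligner = AlignerABC.from_config({"type": "kag"})  # register name declared above
    merged = aligner.invoke([SubGraph(nodes=[], edges=[]), SubGraph(nodes=[], edges=[])])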

@@ -12,8 +12,9 @@
from typing import List, Type, Dict
from kag.interface.builder import AlignerABC
from kag.interface import AlignerABC
from knext.schema.client import BASIC_TYPES
from kag.common.conf import KAG_PROJECT_CONF
from kag.builder.model.spg_record import SPGRecord
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
@@ -21,10 +22,17 @@ from knext.schema.client import SchemaClient
from knext.schema.model.base import ConstraintTypeEnum, BaseSpgType
class SPGPostProcessorAligner(AlignerABC):
@AlignerABC.register("spg")
class SPGAligner(AlignerABC):
"""
A class that extends the AlignerABC base class. It is responsible for aligning and merging SPG records into subgraphs.
This class provides methods to handle the alignment and merging of SPG records, as well as properties to define the input and output types.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spg_types = SchemaClient(project_id=self.project_id).load()
self.spg_types = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
@property
def input_types(self) -> Type[Input]:
@@ -35,6 +43,15 @@ class SPGPostProcessorAligner(AlignerABC):
return SubGraph
def merge(self, spg_records: List[SPGRecord]):
"""
Merges a list of SPG records into a single set of records, combining properties as necessary.
Args:
spg_records (List[SPGRecord]): A list of SPG records to be merged.
Returns:
List[SPGRecord]: A list of merged SPG records.
"""
merged_spg_records = {}
for record in spg_records:
key = f"{record.spg_type_name}#{record.get_property('name', '')}"
@@ -75,6 +92,16 @@ class SPGPostProcessorAligner(AlignerABC):
def from_spg_record(
spg_types: Dict[str, BaseSpgType], spg_records: List[SPGRecord]
):
"""
Converts a list of SPG records into a subgraph.
Args:
spg_types (Dict[str, BaseSpgType]): A dictionary mapping SPG type names to their corresponding types.
spg_records (List[SPGRecord]): A list of SPG records to be converted.
Returns:
SubGraph: A subgraph representing the converted SPG records.
"""
sub_graph = SubGraph([], [])
for record in spg_records:
s_id = record.id
@@ -107,10 +134,30 @@ class SPGPostProcessorAligner(AlignerABC):
return sub_graph
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Processes a single input and returns a list of outputs.
Args:
input (Input): The input to be processed.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list containing the processed output.
"""
subgraph = SubGraph.from_spg_record(self.spg_types, [input])
return [subgraph]
def batch(self, inputs: List[Input], **kwargs) -> List[Output]:
"""
Processes a batch of inputs and returns a list of outputs.
Args:
inputs (List[Input]): A list of inputs to be processed.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of outputs corresponding to the processed inputs.
"""
merged_records = self.merge(inputs)
subgraph = SubGraph.from_spg_record(self.spg_types, merged_records)
return [subgraph]

@@ -1,81 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from abc import ABC
from typing import List, Dict
import logging
from knext.common.base.component import Component
from knext.common.base.runnable import Input, Output
from knext.project.client import ProjectClient
from kag.common.llm.client import LLMClient
class BuilderComponent(Component, ABC):
"""
Abstract base class for all builder component.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.project_id = kwargs.get("project_id",None) or os.getenv("KAG_PROJECT_ID")
self.config = ProjectClient().get_config(self.project_id)
def _init_llm(self) -> LLMClient:
"""
Initializes the Large Language Model (LLM) client.
This method retrieves the LLM configuration from environment variables and the project ID.
It then fetches the project configuration using the project ID and updates the LLM configuration
with any additional settings from the project. Finally, it creates and initializes the LLM client
using the updated configuration.
Args:
None
Returns:
LLMClient
"""
llm_config = eval(os.getenv("KAG_LLM", "{}"))
project_id = self.project_id or os.getenv("KAG_PROJECT_ID")
if project_id:
try:
config = ProjectClient().get_config(project_id)
llm_config.update(config.get("llm", {}))
except:
logging.warning(
f"Failed to get project config for project id: {project_id}"
)
llm = LLMClient.from_config(llm_config)
return llm
@property
def type(self):
"""
Get the type label of the object.
Returns:
str: The type label of the object, fixed as "BUILDER".
"""
return "BUILDER"
def batch(self, inputs: List[Input], **kwargs) -> List[Output]:
results = []
for input in inputs:
results.extend(self.invoke(input, **kwargs))
return results
def _handle(self, input: Dict) -> List[Dict]:
_input = self.input_types.from_dict(input) if isinstance(input, dict) else input
_output = self.invoke(_input)
return [_o.to_dict() for _o in _output if _o]

@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import numpy as np
import logging
from typing import List, Union, Dict
from kag.interface import ExternalGraphLoaderABC, MatchConfig
from kag.common.conf import KAG_PROJECT_CONF
from kag.builder.model.sub_graph import Node, Edge, SubGraph
from knext.schema.client import SchemaClient
from knext.search.client import SearchClient
logger = logging.getLogger()
@ExternalGraphLoaderABC.register("base", constructor="from_json_file", as_default=True)
class DefaultExternalGraphLoader(ExternalGraphLoaderABC):
"""
A default implementation of the ExternalGraphLoaderABC interface.
This class is responsible for loading external graph data based on the provided nodes, edges, and match configuration.
"""
def __init__(
self,
nodes: List[Node],
edges: List[Edge],
match_config: MatchConfig,
):
"""
Initializes the DefaultExternalGraphLoader with the given nodes, edges, and match configuration.
Args:
nodes (List[Node]): A list of Node objects representing the nodes in the graph.
edges (List[Edge]): A list of Edge objects representing the edges in the graph.
match_config (MatchConfig): The configuration for matching query str to graph nodes.
"""
super().__init__()
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
for node in nodes:
if node.label not in self.schema:
raise ValueError(
f"Type of node {node.to_dict()} is beyond the schema definition."
)
for k in node.properties.keys():
if k not in self.schema[node.label]:
raise ValueError(
f"Property of node {node.to_dict()} is beyond the schema definition."
)
self.nodes = nodes
self.edges = edges
self.vocabulary = {}
self.node_labels = set()
for node in self.nodes:
self.vocabulary[node.name] = node
self.node_labels.add(node.label)
import jieba
for word in self.vocabulary.keys():
jieba.add_word(word)
self.match_config = match_config
self._init_search()
def _init_search(self):
self._search_client = SearchClient(
KAG_PROJECT_CONF.host_addr, KAG_PROJECT_CONF.project_id
)
def _group_by_label(self, data: Union[List[Node], List[Edge]]):
groups = {}
for item in data:
label = item.label
if label not in groups:
groups[label] = [item]
else:
groups[label].append(item)
return list(groups.values())
def _group_by_cnt(self, data, n):
return [data[i : i + n] for i in range(0, len(data), n)]
def dump(self, max_num_nodes: int = 4096, max_num_edges: int = 4096):
graphs = []
# process nodes
for item in self._group_by_label(self.nodes):
for grouped_nodes in self._group_by_cnt(item, max_num_nodes):
graphs.append(SubGraph(nodes=grouped_nodes, edges=[]))
# process edges
for item in self._group_by_label(self.edges):
for grouped_edges in self._group_by_cnt(item, max_num_edges):
graphs.append(SubGraph(nodes=[], edges=grouped_edges))
return graphs
def ner(self, content: str):
output = []
import jieba
for word in jieba.cut(content):
if word in self.vocabulary:
output.append(self.vocabulary[word])
return output
def get_allowed_labels(self, labels: List[str] = None):
allowed_labels = []
namespace = KAG_PROJECT_CONF.namespace
if labels is None:
allowed_labels = [f"{namespace}.{x}" for x in self.node_labels]
else:
for label in labels:
# remove namespace
if label.startswith(KAG_PROJECT_CONF.namespace):
label = label.split(".")[1]
if label in self.node_labels:
allowed_labels.append(f"{namespace}.{label}")
return allowed_labels
def search_result_to_node(self, search_result: Dict):
output = []
for label in search_result["__labels__"]:
node = {
"id": search_result["id"],
"name": search_result["name"],
"label": label,
}
output.append(Node.from_dict(node))
return output
def text_match(self, query: str, k: int = 1, labels: List[str] = None):
allowed_labels = self.get_allowed_labels(labels)
text_matched = self._search_client.search_text(query, allowed_labels, topk=k)
return text_matched
def vector_match(
self,
query: Union[List[float], np.ndarray],
k: int = 1,
threshold: float = 0.9,
labels: List[str] = None,
):
allowed_labels = self.get_allowed_labels(labels)
if isinstance(query, np.ndarray):
query = query.tolist()
matched_results = []
for label in allowed_labels:
vector_matched = self._search_client.search_vector(
label=label, property_key="name", query_vector=query, topk=k
)
matched_results.extend(vector_matched)
filtered_results = []
for item in matched_results:
score = item["score"]
if score >= threshold:
filtered_results.append(item)
return filtered_results
def match_entity(self, query: Union[str, List[float], np.ndarray]):
if isinstance(query, str):
return self.text_match(
query, k=self.match_config.k, labels=self.match_config.labels
)
else:
return self.vector_match(
query,
k=self.match_config.k,
labels=self.match_config.labels,
threshold=self.match_config.threshold,
)
@classmethod
def from_json_file(
cls,
node_file_path: str,
edge_file_path: str,
match_config: MatchConfig,
):
"""
Creates an instance of DefaultExternalGraphLoader from JSON files containing node and edge data.
Args:
node_file_path (str): The path to the JSON file containing node data.
edge_file_path (str): The path to the JSON file containing edge data.
match_config (MatchConfig): The configuration for matching query str to graph nodes.
Returns:
DefaultExternalGraphLoader: An instance of DefaultExternalGraphLoader initialized with the data from the JSON files.
"""
nodes = []
for item in json.load(open(node_file_path, "r")):
nodes.append(Node.from_dict(item))
edges = []
for item in json.load(open(edge_file_path, "r")):
edges.append(Edge.from_dict(item))
return cls(nodes=nodes, edges=edges, match_config=match_config)
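A sketch of what the node JSON file consumed by from_json_file might contain, inferred from the Node.from_dict usage above; the label, property names and values are illustrative and must match the project schema:

    import json

    nodes = [
        {
            "id": "Apple Inc.",
            "name": "Apple Inc.",
            "label": "Company",  # must exist in the schema
            "properties": {"desc": "consumer electronics company"},  # illustrative
        }
    ]
    with open("nodes.json", "w", encoding="utf-8") as f:
        json.dump(nodes, f, ensure_ascii=False)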

@@ -1,23 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.extractor.kag_extractor import KAGExtractor
from kag.builder.component.extractor.spg_extractor import SPGExtractor
from kag.builder.component.extractor.user_defined_extractor import (
UserDefinedExtractor,
)
__all__ = [
"KAGExtractor",
"SPGExtractor",
"UserDefinedExtractor",
]

@@ -0,0 +1,429 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
import logging
from typing import Dict, Type, List
from kag.interface import LLMClient
from tenacity import stop_after_attempt, retry
from kag.interface import ExtractorABC, PromptABC, ExternalGraphLoaderABC
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import processing_phrases, to_camel_case
from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from kag.builder.prompt.utils import init_prompt_with_fallback
from knext.schema.client import CHUNK_TYPE, BASIC_TYPES
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
logger = logging.getLogger(__name__)
@ExtractorABC.register("schema_constraint")
@ExtractorABC.register("schema_constraint_extractor")
class SchemaConstraintExtractor(ExtractorABC):
"""
Perform knowledge extraction for enforcing schema constraints, including entities, events and their edges.
The types of entities and events, along with their respective attributes, are automatically inherited from the project's schema.
"""
def __init__(
self,
llm: LLMClient,
ner_prompt: PromptABC = None,
std_prompt: PromptABC = None,
relation_prompt: PromptABC = None,
event_prompt: PromptABC = None,
external_graph: ExternalGraphLoaderABC = None,
):
"""
Initializes the SchemaBasedExtractor instance.
Args:
llm (LLMClient): The language model client used for extraction.
ner_prompt (PromptABC, optional): The prompt for named entity recognition. Defaults to None.
std_prompt (PromptABC, optional): The prompt for named entity standardization. Defaults to None.
relation_prompt (PromptABC, optional): The prompt for relation extraction. Defaults to None.
event_prompt (PromptABC, optional): The prompt for event extraction. Defaults to None.
external_graph (ExternalGraphLoaderABC, optional): The external graph loader for additional data. Defaults to None.
"""
super().__init__()
self.llm = llm
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
self.ner_prompt = ner_prompt
self.std_prompt = std_prompt
self.relation_prompt = relation_prompt
self.event_prompt = event_prompt
biz_scene = KAG_PROJECT_CONF.biz_scene
if self.ner_prompt is None:
self.ner_prompt = init_prompt_with_fallback("ner", biz_scene)
if self.std_prompt is None:
self.std_prompt = init_prompt_with_fallback("std", biz_scene)
self.external_graph = external_graph
@property
def input_types(self) -> Type[Input]:
return Chunk
@property
def output_types(self) -> Type[Output]:
return SubGraph
@retry(stop=stop_after_attempt(3))
def named_entity_recognition(self, passage: str):
"""
Performs named entity recognition on a given text passage.
Args:
passage (str): The text to perform named entity recognition on.
Returns:
The result of the named entity recognition operation.
"""
ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
if self.external_graph:
extra_ner_result = self.external_graph.ner(passage)
else:
extra_ner_result = []
output = []
dedup = set()
for item in extra_ner_result:
name = item.name
if name not in dedup:
dedup.add(name)
output.append(
{
"name": name,
"category": item.label,
"properties": item.properties,
}
)
for item in ner_result:
name = item.get("name", None)
category = item.get("category", None)
if name is None or category is None:
continue
if not isinstance(name, str):
continue
if name not in dedup:
dedup.add(name)
output.append(item)
return output
@retry(stop=stop_after_attempt(3))
def named_entity_standardization(self, passage: str, entities: List[Dict]):
"""
Performs named entity standardization on a given text passage and entities.
Args:
passage (str): The text passage.
entities (List[Dict]): The list of entities to standardize.
Returns:
The result of the named entity standardization operation.
"""
return self.llm.invoke(
{"input": passage, "named_entities": entities}, self.std_prompt
)
@retry(stop=stop_after_attempt(3))
def relations_extraction(self, passage: str, entities: List[Dict]):
"""
Performs relation extraction on a given text passage and entities.
Args:
passage (str): The text passage.
entities (List[Dict]): The list of entities.
Returns:
The result of the relation extraction operation.
"""
if self.relation_prompt is None:
logger.debug("Relation extraction prompt not configured, skip.")
return []
return self.llm.invoke(
{"input": passage, "entity_list": entities}, self.relation_prompt
)
@retry(stop=stop_after_attempt(3))
def event_extraction(self, passage: str):
"""
Performs event extraction on a given text passage.
Args:
passage (str): The text passage.
Returns:
The result of the event extraction operation.
"""
if self.event_prompt is None:
logger.debug("Event extraction prompt not configured, skip.")
return []
return self.llm.invoke({"input": passage}, self.event_prompt)
def parse_nodes_and_edges(self, entities: List[Dict], category: str = None):
"""
Parses nodes and edges from a list of entities.
Args:
entities (List[Dict]): The list of entities.
Returns:
Tuple[List[Node], List[Edge]]: The parsed nodes and edges.
"""
graph = SubGraph([], [])
entities = copy.deepcopy(entities)
root_nodes = []
for record in entities:
if record is None:
continue
if isinstance(record, str):
record = {"name": record}
s_name = record.get("name", "")
s_label = record.get("category", category)
properties = record.get("properties", {})
# At times, the name and/or label is placed in the properties.
if not s_name:
s_name = properties.pop("name", "")
if not s_label:
s_label = properties.pop("category", "")
if not s_name or not s_label:
continue
s_name = processing_phrases(s_name)
root_nodes.append((s_name, s_label))
tmp_properties = copy.deepcopy(properties)
spg_type = self.schema.get(s_label)
for prop_name, prop_value in properties.items():
if prop_value is None:
tmp_properties.pop(prop_name)
continue
if prop_name in spg_type.properties:
prop_schema = spg_type.properties.get(prop_name)
o_label = prop_schema.object_type_name_en
if o_label not in BASIC_TYPES:
# pop and convert property to node and edge
if not isinstance(prop_value, list):
prop_value = [prop_value]
(
new_root_nodes,
new_nodes,
new_edges,
) = self.parse_nodes_and_edges(prop_value, o_label)
graph.nodes.extend(new_nodes)
graph.edges.extend(new_edges)
# connect current node to property generated nodes
for node in new_root_nodes:
graph.add_edge(
s_id=s_name,
s_label=s_label,
p=prop_name,
o_id=node[0],
o_label=node[1],
)
tmp_properties.pop(prop_name)
record["properties"] = tmp_properties
# NOTE: For property converted to nodes/edges, we keep a copy of the original property values.
# Perhaps it is not necessary?
graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties)
if "official_name" in record:
official_name = processing_phrases(record["official_name"])
if official_name != s_name:
graph.add_node(
id=official_name,
name=official_name,
label=s_label,
properties=properties,
)
graph.add_edge(
s_id=s_name,
s_label=s_label,
p="OfficialName",
o_id=official_name,
o_label=s_label,
)
return root_nodes, graph.nodes, graph.edges
@staticmethod
def add_relations_to_graph(
sub_graph: SubGraph, entities: List[Dict], relations: List[list]
):
"""
Add edges to the subgraph based on a list of relations and entities.
Args:
sub_graph (SubGraph): The subgraph to add edges to.
entities (List[Dict]): A list of entities, for looking up category information.
relations (List[list]): A list of relations, each representing a relationship to be added to the subgraph.
Returns:
The constructed subgraph.
"""
for rel in relations:
if len(rel) != 5:
continue
s_name, s_category, predicate, o_name, o_category = rel
s_name = processing_phrases(s_name)
sub_graph.add_node(s_name, s_name, s_category)
o_name = processing_phrases(o_name)
sub_graph.add_node(o_name, o_name, o_category)
edge_type = to_camel_case(predicate)
if edge_type:
sub_graph.add_edge(s_name, s_category, edge_type, o_name, o_category)
return sub_graph
@staticmethod
def add_chunk_to_graph(sub_graph: SubGraph, chunk: Chunk):
"""
Associates a Chunk object with the subgraph, adding it as a node and connecting it with existing nodes.
Args:
sub_graph (SubGraph): The subgraph to add the chunk information to.
chunk (Chunk): The chunk object containing the text and metadata.
Returns:
The constructed subgraph.
"""
for node in sub_graph.nodes:
sub_graph.add_edge(node.id, node.label, "source", chunk.id, CHUNK_TYPE)
sub_graph.add_node(
id=chunk.id,
name=chunk.name,
label=CHUNK_TYPE,
properties={
"id": chunk.id,
"name": chunk.name,
"content": f"{chunk.name}\n{chunk.content}",
**chunk.kwargs,
},
)
sub_graph.id = chunk.id
return sub_graph
def assemble_subgraph(
self,
chunk: Chunk,
entities: List[Dict],
relations: List[list],
events: List[Dict],
):
"""
Assembles a subgraph from the given chunk, entities, events, and relations.
Args:
chunk (Chunk): The chunk object.
entities (List[Dict]): The list of entities.
events (List[Dict]): The list of events.
Returns:
The constructed subgraph.
"""
graph = SubGraph([], [])
_, entity_nodes, entity_edges = self.parse_nodes_and_edges(entities)
graph.nodes.extend(entity_nodes)
graph.edges.extend(entity_edges)
_, event_nodes, event_edges = self.parse_nodes_and_edges(events)
graph.nodes.extend(event_nodes)
graph.edges.extend(event_edges)
self.add_relations_to_graph(graph, entities, relations)
self.add_chunk_to_graph(graph, chunk)
return graph
def append_official_name(
self, source_entities: List[Dict], entities_with_official_name: List[Dict]
):
"""
Appends official names to entities.
Args:
source_entities (List[Dict]): A list of source entities.
entities_with_official_name (List[Dict]): A list of entities with official names.
"""
tmp_dict = {}
for tmp_entity in entities_with_official_name:
name = tmp_entity["name"]
category = tmp_entity["category"]
official_name = tmp_entity["official_name"]
key = f"{category}{name}"
tmp_dict[key] = official_name
for tmp_entity in source_entities:
name = tmp_entity["name"]
category = tmp_entity["category"]
key = f"{category}{name}"
if key in tmp_dict:
official_name = tmp_dict[key]
tmp_entity["official_name"] = official_name
def postprocess_graph(self, graph):
"""
Postprocesses the graph by merging nodes with the same name and label.
Args:
graph (SubGraph): The graph to postprocess.
Returns:
The postprocessed graph.
"""
try:
all_node_properties = {}
for node in graph.nodes:
id_ = node.id
name = node.name
label = node.label
key = (id_, name, label)
if key not in all_node_properties:
all_node_properties[key] = node.properties
else:
all_node_properties[key].update(node.properties)
new_graph = SubGraph([], [])
for key, node_properties in all_node_properties.items():
id_, name, label = key
new_graph.add_node(
id=id_, name=name, label=label, properties=node_properties
)
new_graph.edges = graph.edges
return new_graph
except:
return graph
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the extractor on the given input.
Args:
input (Input): The input data.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: The list of output results.
"""
title = input.name
passage = title + "\n" + input.content
out = []
entities = self.named_entity_recognition(passage)
events = self.event_extraction(passage)
named_entities = []
for entity in entities:
named_entities.append(
{"name": entity["name"], "category": entity["category"]}
)
relations = self.relations_extraction(passage, named_entities)
std_entities = self.named_entity_standardization(passage, named_entities)
self.append_official_name(entities, std_entities)
subgraph = self.assemble_subgraph(input, entities, relations, events)
out.append(self.postprocess_graph(subgraph))
logger.debug(f"input passage:\n{passage}")
logger.debug(f"output graphs:\n{out}")
return out
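A sketch of constructing this extractor through the registry, relying on the prompt fallbacks above when the ner/std prompts are omitted; the "type" key, the nested llm block, and the assumption that the registry resolves nested component configs are illustrative:

    from kag.interface import ExtractorABC

    extractor = ExtractorABC.from_config(
        {
            "type": "schema_constraint_extractor",       # or the "schema_constraint" alias
            "llm": {"type": "maas", "model": "gpt-4o"},  # illustrative llm config
        }
    )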

@@ -11,64 +11,75 @@
# or implied.
import copy
import logging
import os
from typing import Dict, Type, List
from kag.interface import LLMClient
from tenacity import stop_after_attempt, retry
from kag.builder.prompt.spg_prompt import SPG_KGPrompt
from kag.interface.builder import ExtractorABC
from kag.common.base.prompt_op import PromptOp
from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES
from kag.interface import ExtractorABC, PromptABC, ExternalGraphLoaderABC
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import processing_phrases, to_camel_case
from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from kag.builder.prompt.utils import init_prompt_with_fallback
from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.base import SpgTypeEnum
logger = logging.getLogger(__name__)
class KAGExtractor(ExtractorABC):
@ExtractorABC.register("schema_free")
@ExtractorABC.register("schema_free_extractor")
class SchemaFreeExtractor(ExtractorABC):
"""
A class for extracting knowledge graph subgraphs from text using a large language model (LLM).
Inherits from the Extractor base class.
Attributes:
llm (LLMClient): The large language model client used for text processing.
schema (SchemaClient): The schema client used to load the schema for the project.
ner_prompt (PromptABC): The prompt used for named entity recognition.
std_prompt (PromptABC): The prompt used for named entity standardization.
triple_prompt (PromptABC): The prompt used for triple extraction.
external_graph (ExternalGraphLoaderABC): The external graph loader used for additional NER.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.llm = self._init_llm()
self.prompt_config = self.config.get("prompt", {})
self.biz_scene = self.prompt_config.get("biz_scene") or os.getenv(
"KAG_PROMPT_BIZ_SCENE", "default"
)
self.language = self.prompt_config.get("language") or os.getenv(
"KAG_PROMPT_LANGUAGE", "en"
)
self.schema = SchemaClient(project_id=self.project_id).load()
self.ner_prompt = PromptOp.load(self.biz_scene, "ner")(
language=self.language, project_id=self.project_id
)
self.std_prompt = PromptOp.load(self.biz_scene, "std")(language=self.language)
self.triple_prompt = PromptOp.load(self.biz_scene, "triple")(
language=self.language
)
self.kg_types = []
for type_name, spg_type in self.schema.items():
if type_name in SPG_KGPrompt.ignored_types:
continue
if spg_type.spg_type_enum == SpgTypeEnum.Concept:
continue
properties = list(spg_type.properties.keys())
for p in properties:
if p not in SPG_KGPrompt.ignored_properties:
self.kg_types.append(type_name)
break
if self.kg_types:
self.kg_prompt = SPG_KGPrompt(
self.kg_types, language=self.language, project_id=self.project_id
)
def __init__(
self,
llm: LLMClient,
ner_prompt: PromptABC = None,
std_prompt: PromptABC = None,
triple_prompt: PromptABC = None,
external_graph: ExternalGraphLoaderABC = None,
):
"""
Initializes the KAGExtractor with the specified parameters.
Args:
llm (LLMClient): The large language model client.
ner_prompt (PromptABC, optional): The prompt for named entity recognition. Defaults to None.
std_prompt (PromptABC, optional): The prompt for named entity standardization. Defaults to None.
triple_prompt (PromptABC, optional): The prompt for triple extraction. Defaults to None.
external_graph (ExternalGraphLoaderABC, optional): The external graph loader. Defaults to None.
"""
super().__init__()
self.llm = llm
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
self.ner_prompt = ner_prompt
self.std_prompt = std_prompt
self.triple_prompt = triple_prompt
biz_scene = KAG_PROJECT_CONF.biz_scene
if self.ner_prompt is None:
self.ner_prompt = init_prompt_with_fallback("ner", biz_scene)
if self.std_prompt is None:
self.std_prompt = init_prompt_with_fallback("std", biz_scene)
if self.triple_prompt is None:
self.triple_prompt = init_prompt_with_fallback("triple", biz_scene)
self.external_graph = external_graph
@property
def input_types(self) -> Type[Input]:
@@ -87,12 +98,34 @@ class KAGExtractor(ExtractorABC):
Returns:
The result of the named entity recognition operation.
"""
if self.kg_types:
kg_result = self.llm.invoke({"input": passage}, self.kg_prompt)
else:
kg_result = []
ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
return kg_result + ner_result
if self.external_graph:
extra_ner_result = self.external_graph.ner(passage)
else:
extra_ner_result = []
output = []
dedup = set()
for item in extra_ner_result:
name = item.name
label = item.label
description = item.properties.get("desc", "")
semantic_type = item.properties.get("semanticType", label)
if name not in dedup:
dedup.add(name)
output.append(
{
"name": name,
"type": semantic_type,
"category": label,
"description": description,
}
)
for item in ner_result:
name = item.get("name", None)
if name and name not in dedup:
dedup.add(name)
output.append(item)
return output
@retry(stop=stop_after_attempt(3))
def named_entity_standardization(self, passage: str, entities: List[Dict]):
@@ -125,20 +158,26 @@ class KAGExtractor(ExtractorABC):
)
def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
"""
Assembles a subgraph using SPG records.
Args:
entities (List[Dict]): A list of entities to be used for subgraph assembly.
Returns:
The assembled subgraph and the updated list of entities.
"""
sub_graph = SubGraph([], [])
for record in entities:
s_name = record.get("entity", "")
s_name = record.get("name", "")
s_label = record.get("category", "")
properties = record.get("properties", {})
tmp_properties = copy.deepcopy(properties)
spg_type = self.schema.get(s_label)
if not spg_type:
continue
for prop_name, prop_value in properties.items():
if prop_value == "NAN":
tmp_properties.pop(prop_name)
continue
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
@ -173,11 +212,14 @@ class KAGExtractor(ExtractorABC):
sub_graph (SubGraph): The subgraph to add edges to.
entities (List[Dict]): A list of entities, for looking up category information.
triples (List[list]): A list of triples, each representing a relationship to be added to the subgraph.
Returns:
The constructed subgraph.
"""
def get_category(entities_data, entity_name):
for entity in entities_data:
if entity["entity"] == entity_name:
if entity["name"] == entity_name:
return entity["category"]
return None
@ -194,7 +236,6 @@ class KAGExtractor(ExtractorABC):
if o_category is None:
o_category = OTHER_TYPE
sub_graph.add_node(tri[2], tri[2], o_category)
edge_type = to_camel_case(tri[1])
if edge_type:
sub_graph.add_edge(tri[0], s_category, edge_type, tri[2], o_category)
@ -208,6 +249,8 @@ class KAGExtractor(ExtractorABC):
Args:
sub_graph (SubGraph): The subgraph to add the chunk information to.
chunk (Chunk): The chunk object containing the text and metadata.
Returns:
The constructed subgraph.
"""
for node in sub_graph.nodes:
sub_graph.add_edge(node.id, node.label, "source", chunk.id, CHUNK_TYPE)
@ -240,7 +283,7 @@ class KAGExtractor(ExtractorABC):
entities (List[Dict]): A list of entities identified in the chunk.
triples (List[list]): A list of triples representing relationships between entities.
Returns:
SubGraph: The constructed subgraph.
The constructed subgraph.
"""
self.assemble_sub_graph_with_entities(sub_graph, entities)
self.assemble_sub_graph_with_triples(sub_graph, entities, triples)
@ -259,7 +302,7 @@ class KAGExtractor(ExtractorABC):
"""
for ent in entities:
name = processing_phrases(ent["entity"])
name = processing_phrases(ent["name"])
sub_graph.add_node(
name,
name,
@ -302,26 +345,31 @@ class KAGExtractor(ExtractorABC):
source_entities (List[Dict]): A list of source entities.
entities_with_official_name (List[Dict]): A list of entities with official names.
"""
tmp_dict = {}
for tmp_entity in entities_with_official_name:
name = tmp_entity["entity"]
category = tmp_entity["category"]
official_name = tmp_entity["official_name"]
key = f"{category}{name}"
tmp_dict[key] = official_name
try:
tmp_dict = {}
for tmp_entity in entities_with_official_name:
if "name" in tmp_entity:
name = tmp_entity["name"]
elif "entity" in tmp_entity:
name = tmp_entity["entity"]
else:
continue
category = tmp_entity["category"]
official_name = tmp_entity["official_name"]
key = f"{category}{name}"
tmp_dict[key] = official_name
for tmp_entity in source_entities:
name = tmp_entity["entity"]
category = tmp_entity["category"]
key = f"{category}{name}"
if key in tmp_dict:
official_name = tmp_dict[key]
tmp_entity["official_name"] = official_name
for tmp_entity in source_entities:
name = tmp_entity["name"]
category = tmp_entity["category"]
key = f"{category}{name}"
if key in tmp_dict:
official_name = tmp_dict[key]
tmp_entity["official_name"] = official_name
except Exception as e:
logger.warn(f"failed to process official name, info: {e}")
def quoteStr(self, input: str) -> str:
return f"""{input}"""
def invoke(self, input: Input, **kwargs) -> List[Output]:
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the semantic extractor to process input data.
@ -332,24 +380,19 @@ class KAGExtractor(ExtractorABC):
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
title = input.name
passage = self.quoteStr(title + "\n" + input.content)
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
filtered_entities = [
{k: v for k, v in ent.items() if k in ["entity", "category"]}
for ent in entities
]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
self.assemble_sub_graph(sub_graph, input, entities, triples)
return [sub_graph]
except Exception as e:
import traceback
traceback.print_exc()
logger.info(e)
return []
passage = title + "\n" + input.content
out = []
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
filtered_entities = [
{k: v for k, v in ent.items() if k in ["name", "category"]}
for ent in entities
]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
self.assemble_sub_graph(sub_graph, input, entities, triples)
out.append(sub_graph)
return out
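For orientation, here is a minimal usage sketch of the refactored extractor. It assumes an already-configured LLMClient instance and that ExtractorABC routes the public invoke() call to the new _invoke(); the chunk contents are invented for illustration.

from kag.builder.model.chunk import Chunk

llm = ...  # an LLMClient configured elsewhere (construction omitted)
extractor = KAGExtractor(llm=llm)  # ner/std/triple prompts fall back to the biz_scene defaults
chunk = Chunk(id="doc-1", name="demo", content="Alice works at Acme Corp.")
for sub_graph in extractor.invoke(chunk):
    print(len(sub_graph.nodes), len(sub_graph.edges))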

View File

@ -1,116 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
import logging
from typing import List, Dict
from tenacity import retry, stop_after_attempt
from kag.builder.component.extractor import KAGExtractor
from kag.builder.model.sub_graph import SubGraph
from kag.builder.prompt.spg_prompt import SPG_KGPrompt
from kag.common.base.prompt_op import PromptOp
from knext.common.base.runnable import Input, Output
from knext.schema.client import BASIC_TYPES
logger = logging.getLogger(__name__)
class SPGExtractor(KAGExtractor):
"""
A Builder Component that extracts structured data from long texts by invoking a large language model.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spg_ner_types, self.kag_ner_types = [], []
for type_name, spg_type in self.schema.items():
properties = list(spg_type.properties.keys())
for p in properties:
if p not in SPG_KGPrompt.ignored_properties:
self.spg_ner_types.append(type_name)
continue
self.kag_ner_types.append(type_name)
self.kag_ner_prompt = PromptOp.load(self.biz_scene, "ner")(language=self.language, project_id=self.project_id)
self.spg_ner_prompt = SPG_KGPrompt(self.spg_ner_types, self.language, project_id=self.project_id)
@retry(stop=stop_after_attempt(3))
def named_entity_recognition(self, passage: str):
"""
Performs named entity recognition on a given text passage.
Args:
passage (str): The text to perform named entity recognition on.
Returns:
The result of the named entity recognition operation.
"""
spg_ner_result = self.llm.batch({"input": passage}, self.spg_ner_prompt)
kag_ner_result = self.llm.invoke({"input": passage}, self.kag_ner_prompt)
return spg_ner_result + kag_ner_result
def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
sub_graph = SubGraph([], [])
for record in entities:
s_name = record.get("entity", "")
s_label = record.get("category", "")
properties = record.get("properties", {})
tmp_properties = copy.deepcopy(properties)
spg_type = self.schema.get(s_label)
for prop_name, prop_value in properties.items():
if prop_value == "NAN":
tmp_properties.pop(prop_name)
continue
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
if isinstance(prop_value, str):
prop_value = [prop_value]
for o_name in prop_value:
sub_graph.add_node(id=o_name, name=o_name, label=o_label)
sub_graph.add_edge(s_id=s_name, s_label=s_label, p=prop_name, o_id=o_name, o_label=o_label)
tmp_properties.pop(prop_name)
record["properties"] = tmp_properties
sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties)
return sub_graph, entities
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the semantic extractor to process input data.
Args:
input (Input): Input data containing name and content.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
title = input.name
passage = title + "\n" + input.content
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
filtered_entities = [{k: v for k, v in ent.items() if k in ["entity", "category"]} for ent in entities]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
self.assemble_sub_graph(sub_graph, input, entities, triples)
return [sub_graph]
except Exception as e:
import traceback
traceback.print_exc()
logger.info(e)
return []

View File

@ -1,21 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.mapping.spg_type_mapping import SPGTypeMapping
from kag.builder.component.mapping.relation_mapping import RelationMapping
from kag.builder.component.mapping.spo_mapping import SPOMapping
__all__ = [
"SPGTypeMapping",
"RelationMapping",
"SPOMapping",
]

View File

@ -10,40 +10,46 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from collections import defaultdict
from typing import Dict, List
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.schema_helper import (
SPGTypeName,
RelationName,
)
from kag.interface.builder.mapping_abc import MappingABC
from kag.common.conf import KAG_PROJECT_CONF
from kag.interface import MappingABC
@MappingABC.register("relation")
class RelationMapping(MappingABC):
"""
A class that handles relation mappings by assembling subgraphs based on given subject, predicate, and object names.
This class extends the Mapping class.
Args:
subject_name (SPGTypeName): The name of the subject type.
predicate_name (RelationName): The name of the predicate.
object_name (SPGTypeName): The name of the object type.
A class that extends the MappingABC class.
It handles relation mappings by assembling subgraphs based on given subject, predicate, and object names.
"""
def __init__(
self,
subject_name: SPGTypeName,
predicate_name: RelationName,
object_name: SPGTypeName,
**kwargs
subject_name: str,
predicate_name: str,
object_name: str,
src_id_field: str = None,
dst_id_field: str = None,
property_mapping: dict = {},
**kwargs,
):
"""
Initializes the RelationMapping instance.
Args:
subject_name (str): The name of the subject type.
predicate_name (str): The name of the predicate type.
object_name (str): The name of the object type.
src_id_field (str, optional): The field name for the source ID. Defaults to None.
dst_id_field (str, optional): The field name for the destination ID. Defaults to None.
property_mapping (dict, optional): A dictionary mapping properties. Defaults to {}.
**kwargs: Additional keyword arguments passed to the parent class constructor.
"""
super().__init__(**kwargs)
schema = SchemaClient(project_id=self.project_id).load()
schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
assert subject_name in schema, f"{subject_name} is not a valid SPG type name"
assert object_name in schema, f"{object_name} is not a valid SPG type name"
self.subject_type = schema.get(subject_name)
@ -54,10 +60,9 @@ class RelationMapping(MappingABC):
), f"{predicate_name} is not a valid SPG property/relation name"
self.predicate_name = predicate_name
self.src_id_field = None
self.dst_id_field = None
self.property_mapping: Dict = defaultdict(list)
self.linking_strategies: Dict = dict()
self.src_id_field = src_id_field
self.dst_id_field = dst_id_field
self.property_mapping = property_mapping
def add_src_id_mapping(self, source_name: str):
"""
@ -96,7 +101,11 @@ class RelationMapping(MappingABC):
Returns:
self
"""
self.property_mapping[target_name].append(source_name)
if target_name in self.property_mapping:
self.property_mapping[target_name].append(source_name)
else:
self.property_mapping[target_name] = [source_name]
return self
@property
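Taken together, the refactored constructor lets a relation mapping be fully configured up front instead of via setter calls. A minimal sketch, with hypothetical schema type names and column names:

mapping = RelationMapping(
    subject_name="DEFAULT.Person",
    predicate_name="worksAt",
    object_name="DEFAULT.Company",
    src_id_field="person_id",
    dst_id_field="company_id",
    property_mapping={"since": ["start_date"]},
)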

View File

@ -15,33 +15,31 @@ from typing import Dict, List, Callable
import pandas
from knext.schema.client import BASIC_TYPES
from kag.builder.model.sub_graph import SubGraph, Node
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.base import SpgTypeEnum
from knext.schema.model.schema_helper import (
SPGTypeName,
PropertyName,
)
from kag.common.conf import KAG_PROJECT_CONF
from kag.interface.builder.mapping_abc import MappingABC
FuseFunc = Callable[[SubGraph], List[SubGraph]]
LinkFunc = Callable[[str, Node], List[Node]]
from kag.common.registry import Functor
@MappingABC.register("spg")
@MappingABC.register("spg_mapping")
class SPGTypeMapping(MappingABC):
"""
A class for mapping SPG (Simple Property Graph) types and handling their properties and strategies.
A class for mapping SPG (Semantic-enhanced Programmable Graph) types and handling their properties and strategies.
Attributes:
spg_type_name (SPGTypeName): The name of the SPG type.
fuse_op (FuseOpABC, optional): The user-defined fuse operator. Defaults to None.
"""
def __init__(self, spg_type_name: SPGTypeName, fuse_func: FuseFunc = None, **kwargs):
super().__init__(**kwargs)
self.schema = SchemaClient(project_id=self.project_id).load()
def __init__(self, spg_type_name: str, fuse_func: Functor = None):
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
assert (
spg_type_name in self.schema
), f"SPG type [{spg_type_name}] does not exist."
@ -55,7 +53,7 @@ class SPGTypeMapping(MappingABC):
self,
source_name: str,
target_name: PropertyName,
link_func: LinkFunc = None,
link_func: Callable = None,
):
"""
Adds a property mapping from a source name to a target name within the SPG type.
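For reference, a minimal sketch of wiring the updated SPGTypeMapping (the type and column names are hypothetical; invocation then follows the usual MappingABC protocol over structured records):

mapping = SPGTypeMapping(spg_type_name="DEFAULT.Person")
mapping.add_property_mapping(source_name="id", target_name="id")
mapping.add_property_mapping(source_name="person_name", target_name="name")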

View File

@ -10,7 +10,6 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from collections import defaultdict
from typing import List, Type, Dict
from kag.interface.builder.mapping_abc import MappingABC
@ -19,17 +18,44 @@ from knext.common.base.runnable import Input, Output
from knext.schema.client import OTHER_TYPE
@MappingABC.register("spo")
@MappingABC.register("spo_mapping")
class SPOMapping(MappingABC):
"""
A class that extends the MappingABC base class.
It is responsible for mapping structured dictionaries to a list of SubGraphs.
"""
def __init__(self):
def __init__(
self,
s_type_col: str = None,
s_id_col: str = None,
p_type_col: str = None,
o_type_col: str = None,
o_id_col: str = None,
sub_property_col: str = None,
sub_property_mapping: dict = {},
):
"""
Initializes the SPOMapping instance.
Args:
s_type_col (str, optional): The column name for the subject type. Defaults to None.
s_id_col (str, optional): The column name for the subject ID. Defaults to None.
p_type_col (str, optional): The column name for the predicate type. Defaults to None.
o_type_col (str, optional): The column name for the object type. Defaults to None.
o_id_col (str, optional): The column name for the object ID. Defaults to None.
sub_property_col (str, optional): The column name for sub-properties. Defaults to None.
sub_property_mapping (dict, optional): A dictionary mapping sub-properties. Defaults to {}.
"""
super().__init__()
self.s_type_col = None
self.s_id_col = None
self.p_type_col = None
self.o_type_col = None
self.o_id_col = None
self.sub_property_mapping = defaultdict(list)
self.sub_property_col = None
self.s_type_col = s_type_col
self.s_id_col = s_id_col
self.p_type_col = p_type_col
self.o_type_col = o_type_col
self.o_id_col = o_id_col
self.sub_property_col = sub_property_col
self.sub_property_mapping = sub_property_mapping
@property
def input_types(self) -> Type[Input]:
@ -39,7 +65,27 @@ class SPOMapping(MappingABC):
def output_types(self) -> Type[Output]:
return SubGraph
def add_field_mappings(self, s_id_col: str, p_type_col: str, o_id_col: str, s_type_col: str = None, o_type_col: str = None):
def add_field_mappings(
self,
s_id_col: str,
p_type_col: str,
o_id_col: str,
s_type_col: str = None,
o_type_col: str = None,
):
"""
Adds field mappings for the subject, predicate, and object types and IDs.
Args:
s_id_col (str): The column name for the subject ID.
p_type_col (str): The column name for the predicate type.
o_id_col (str): The column name for the object ID.
s_type_col (str, optional): The column name for the subject type. Defaults to None.
o_type_col (str, optional): The column name for the object type. Defaults to None.
Returns:
self
"""
self.s_type_col = s_type_col
self.s_id_col = s_id_col
self.p_type_col = p_type_col
@ -63,7 +109,10 @@ class SPOMapping(MappingABC):
if not target_name:
self.sub_property_col = source_name
else:
self.sub_property_mapping[target_name].append(source_name)
if target_name in self.sub_property_mapping:
self.sub_property_mapping[target_name].append(source_name)
else:
self.sub_property_mapping[target_name] = [source_name]
return self
def assemble_sub_graph(self, record: Dict[str, str]):
@ -86,14 +135,21 @@ class SPOMapping(MappingABC):
sub_graph.add_node(id=o_id, name=o_id, label=o_type)
sub_properties = {}
if self.sub_property_col:
sub_properties = json.loads(record.get(self.sub_property_col, '{}'))
sub_properties = json.loads(record.get(self.sub_property_col, "{}"))
sub_properties = {k: str(v) for k, v in sub_properties.items()}
else:
for target_name, source_names in self.sub_property_mapping.items():
for source_name in source_names:
value = record.get(source_name)
sub_properties[target_name] = value
sub_graph.add_edge(s_id=s_id, s_label=s_type, p=p, o_id=o_id, o_label=o_type, properties=sub_properties)
sub_graph.add_edge(
s_id=s_id,
s_label=s_type,
p=p,
o_id=o_id,
o_label=o_type,
properties=sub_properties,
)
return sub_graph
def invoke(self, input: Input, **kwargs) -> List[Output]:
@ -105,7 +161,7 @@ class SPOMapping(MappingABC):
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of resulting sub-graphs.
List[Output]: A list of resulting subgraphs.
"""
record: Dict[str, str] = input
sub_graph = self.assemble_sub_graph(record)
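A minimal usage sketch of the new keyword-argument style (the column names and record are invented; sub_property_col is expected to hold a JSON string, as the json.loads call above shows):

spo = SPOMapping(
    s_type_col="s_type",
    s_id_col="s",
    p_type_col="p",
    o_type_col="o_type",
    o_id_col="o",
    sub_property_col="props",
)
record = {
    "s_type": "Person", "s": "alice",
    "p": "worksAt",
    "o_type": "Company", "o": "acme",
    "props": '{"since": "2020"}',
}
sub_graphs = spo.invoke(record)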

View File

@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
from typing import List
from tenacity import stop_after_attempt, retry
from kag.interface import PostProcessorABC
from kag.interface import ExternalGraphLoaderABC
from kag.builder.model.sub_graph import SubGraph
from kag.common.conf import KAGConstants, KAG_PROJECT_CONF
from kag.common.utils import get_vector_field_name
from knext.search.client import SearchClient
from knext.schema.client import SchemaClient, OTHER_TYPE
logger = logging.getLogger()
@PostProcessorABC.register("base", as_default=True)
@PostProcessorABC.register("kag_post_processor")
class KAGPostProcessor(PostProcessorABC):
"""
A class that extends the PostProcessorABC base class.
It provides methods to handle various post-processing tasks on subgraphs,
including filtering, entity linking based on similarity, and linking based on an external graph.
"""
def __init__(
self,
similarity_threshold: float = 0.9,
external_graph: ExternalGraphLoaderABC = None,
):
"""
Initializes the KAGPostProcessor instance.
Args:
similarity_threshold (float, optional): The similarity threshold for entity linking. Defaults to 0.9.
external_graph (ExternalGraphLoaderABC, optional): An instance of ExternalGraphLoaderABC for external graph-based linking. Defaults to None.
"""
super().__init__()
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
self.similarity_threshold = similarity_threshold
self.external_graph = external_graph
self._init_search()
def format_label(self, label: str):
"""
Formats the label by adding the project namespace if it is not already present.
Args:
label (str): The label to be formatted.
Returns:
str: The formatted label.
"""
namespace = KAG_PROJECT_CONF.namespace
if label.split(".")[0] == namespace:
return label
return f"{namespace}.{label}"
def _init_search(self):
"""
Initializes the search client for entity linking.
"""
self._search_client = SearchClient(
KAG_PROJECT_CONF.host_addr, KAG_PROJECT_CONF.project_id
)
def filter_invalid_data(self, graph: SubGraph):
"""
Filters out invalid nodes and edges from the subgraph.
Args:
graph (SubGraph): The subgraph to be filtered.
Returns:
SubGraph: The filtered subgraph.
"""
valid_nodes = []
valid_edges = []
for node in graph.nodes:
if not node.id or not node.label:
continue
if node.label not in self.schema:
node.label = self.format_label(OTHER_TYPE)
# for k in node.properties.keys():
# if k not in self.schema[node.label]:
# continue
valid_nodes.append(node)
for edge in graph.edges:
if edge.label:
valid_edges.append(edge)
return SubGraph(nodes=valid_nodes, edges=valid_edges)
@retry(stop=stop_after_attempt(3))
def _entity_link(
self, graph: SubGraph, property_key: str = "name", labels: List[str] = None
):
"""
Performs entity linking based on the given property key and labels.
Args:
graph (SubGraph): The subgraph to perform entity linking on.
property_key (str, optional): The property key to use for linking. Defaults to "name".
labels (List[str], optional): The labels to consider for linking. Defaults to None.
"""
vector_field_name = get_vector_field_name(property_key)
for node in graph.nodes:
if labels is None:
link_labels = [self.format_label(node.label)]
else:
link_labels = [self.format_label(x) for x in labels]
vector = node.properties.get(vector_field_name)
if vector:
all_similar_nodes = []
for label in link_labels:
similar_nodes = self._search_client.search_vector(
label=label,
property_key=property_key,
query_vector=[float(x) for x in vector],
topk=1,
params={},
)
all_similar_nodes.extend(similar_nodes)
for item in all_similar_nodes:
score = item["score"]
if (
score >= self.similarity_threshold
and node.id != item["node"]["id"]
):
graph.add_edge(
node.id,
node.label,
KAGConstants.KAG_SIMILAR_EDGE_NAME,
item["node"]["id"],
item["node"]["__labels__"][0],
)
def similarity_based_link(self, graph: SubGraph, property_key: str = "name"):
"""
Performs entity linking based on similarity.
Args:
graph (SubGraph): The subgraph to perform entity linking on.
property_key (str, optional): The property key to use for linking. Defaults to "name".
"""
self._entity_link(graph, property_key, None)
def external_graph_based_link(self, graph: SubGraph, property_key: str = "name"):
"""
Performs entity linking based on the user-provided external graph.
Args:
graph (SubGraph): The subgraph to perform entity linking on.
property_key (str, optional): The property key to use for linking. Defaults to "name".
"""
if not self.external_graph:
return
labels = self.external_graph.get_allowed_labels()
self._entity_link(graph, property_key, labels)
def _invoke(self, input, **kwargs):
"""
Invokes the post-processing pipeline on the input subgraph.
Args:
input: The input subgraph to be processed.
Returns:
List[SubGraph]: A list containing the processed subgraph.
"""
origin_num_nodes = len(input.nodes)
origin_num_edges = len(input.edges)
new_graph = self.filter_invalid_data(input)
self.similarity_based_link(new_graph)
self.external_graph_based_link(new_graph)
new_num_nodes = len(new_graph.nodes)
new_num_edges = len(new_graph.edges)
logger.debug(
f"origin: {origin_num_nodes}/{origin_num_edges}, processed: {new_num_nodes}/{new_num_edges}"
)
return [new_graph]
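A minimal sketch of driving the post-processor directly, assuming PostProcessorABC routes the public invoke() to _invoke() and that sub_graph is the output of the extractor above:

processor = KAGPostProcessor(similarity_threshold=0.85)
processed = processor.invoke(sub_graph)  # filter invalid data, then similarity and external-graph linking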

View File

@ -1,33 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.reader.csv_reader import CSVReader
from kag.builder.component.reader.pdf_reader import PDFReader
from kag.builder.component.reader.json_reader import JSONReader
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.component.reader.docx_reader import DocxReader
from kag.builder.component.reader.txt_reader import TXTReader
from kag.builder.component.reader.dataset_reader import HotpotqaCorpusReader, TwowikiCorpusReader, MusiqueCorpusReader
from kag.builder.component.reader.yuque_reader import YuqueReader
__all__ = [
"TXTReader",
"PDFReader",
"MarkDownReader",
"JSONReader",
"HotpotqaCorpusReader",
"MusiqueCorpusReader",
"TwowikiCorpusReader",
"YuqueReader",
"CSVReader",
"DocxReader",
]

View File

@ -1,89 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List, Type, Dict
import pandas as pd
from kag.builder.model.chunk import Chunk
from kag.interface.builder.reader_abc import SourceReaderABC
from knext.common.base.runnable import Input, Output
class CSVReader(SourceReaderABC):
"""
A class for reading CSV files, inheriting from `SourceReader`.
Supports converting CSV data into either a list of dictionaries or a list of Chunk objects.
Args:
output_type (Output): Specifies the output type, which can be "Dict" or "Chunk".
**kwargs: Additional keyword arguments passed to the parent class constructor.
"""
def __init__(self, output_type="Chunk", **kwargs):
super().__init__(**kwargs)
if output_type == "Dict":
self.output_types = Dict[str, str]
else:
self.output_types = Chunk
self.id_col = kwargs.get("id_col", "id")
self.name_col = kwargs.get("name_col", "name")
self.content_col = kwargs.get("content_col", "content")
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return self._output_types
@output_types.setter
def output_types(self, output_types):
self._output_types = output_types
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Reads a CSV file and converts the data format based on the output type.
Args:
input (Input): Input parameter, expected to be a string representing the path to the CSV file.
**kwargs: Additional keyword arguments, which may include `id_column`, `name_column`, `content_column`, etc.
Returns:
List[Output]:
- If `output_types` is `Chunk`, returns a list of Chunk objects.
- If `output_types` is `Dict`, returns a list of dictionaries.
"""
try:
data = pd.read_csv(input)
data = data.astype(str)
except Exception as e:
raise IOError(f"Failed to read the file: {e}")
if self.output_types == Chunk:
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
for idx, row in enumerate(data.to_dict(orient="records")):
kwargs = {k: v for k, v in row.items() if k not in [self.id_col, self.name_col, self.content_col]}
chunks.append(
Chunk(
id=row.get(self.id_col) or Chunk.generate_hash_id(f"{input}#{idx}"),
name=row.get(self.name_col) or f"{basename}#{idx}",
content=row[self.content_col],
**kwargs
)
)
return chunks
else:
return data.to_dict(orient="records")

View File

@ -1,97 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import os
from typing import List, Type
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
class HotpotqaCorpusReader(SourceReaderABC):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
def invoke(self, input: str, **kwargs) -> List[Output]:
if os.path.exists(str(input)):
with open(input, "r") as f:
corpus = json.load(f)
else:
corpus = json.loads(input)
chunks = []
for item_key, item_value in corpus.items():
chunk = Chunk(
id=item_key,
name=item_key,
content="\n".join(item_value),
)
chunks.append(chunk)
return chunks
class MusiqueCorpusReader(SourceReaderABC):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
def get_basename(self, file_name: str):
base, ext = os.path.splitext(os.path.basename(file_name))
return base
def invoke(self, input: str, **kwargs) -> List[Output]:
id_column = kwargs.get("id_column", "title")
name_column = kwargs.get("name_column", "title")
content_column = kwargs.get("content_column", "text")
if os.path.exists(str(input)):
with open(input, "r") as f:
corpusList = json.load(f)
else:
corpusList = input
chunks = []
for item in corpusList:
chunk = Chunk(
id=item[id_column],
name=item[name_column],
content=item[content_column],
)
chunks.append(chunk)
return chunks
class TwowikiCorpusReader(MusiqueCorpusReader):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk

View File

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Dict, List
from kag.interface import ReaderABC
from knext.common.base.runnable import Output, Input
from kag.builder.model.chunk import Chunk
@ReaderABC.register("dict")
@ReaderABC.register("dict_reader")
class DictReader(ReaderABC):
"""
A class for reading dictionaries into Chunk objects.
This class inherits from ReaderABC and provides the functionality to convert dictionary inputs
into a list of Chunk objects.
Attributes:
id_col (str): The key in the input dictionary that corresponds to the chunk's ID.
name_col (str): The key in the input dictionary that corresponds to the chunk's name.
content_col (str): The key in the input dictionary that corresponds to the chunk's content.
"""
def __init__(
self, id_col: str = "id", name_col: str = "name", content_col: str = "content"
):
"""
Initializes the DictReader with the specified column names.
Args:
id_col (str): The key in the input dictionary that corresponds to the chunk's ID. Defaults to "id".
name_col (str): The key in the input dictionary that corresponds to the chunk's name. Defaults to "name".
content_col (str): The key in the input dictionary that corresponds to the chunk's content. Defaults to "content".
"""
super().__init__()
self.id_col = id_col
self.name_col = name_col
self.content_col = content_col
@property
def input_types(self) -> Input:
return Dict
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Converts the input dictionary into a list of Chunk objects.
Args:
input (Input): The input dictionary containing the data to be parsed.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list containing a single Chunk object created from the input dictionary.
"""
chunk_id = input.get(self.id_col)
chunk_name = input.get(self.name_col)
chunk_content = input.get(self.content_col)
if self.id_col in input:
input.pop(self.id_col)
if self.name_col in input:
input.pop(self.name_col)
if self.content_col in input:
input.pop(self.content_col)
return [Chunk(id=chunk_id, name=chunk_name, content=chunk_content, **input)]
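For example, with a hypothetical record (extra keys such as "lang" are forwarded to the Chunk as keyword arguments; the public invoke() is assumed to delegate to _invoke()):

reader = DictReader(id_col="doc_id", name_col="title", content_col="text")
chunks = reader.invoke({"doc_id": "1", "title": "demo", "text": "hello world", "lang": "en"})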

View File

@ -11,17 +11,17 @@
# or implied.
import os
from typing import List, Type,Union
from typing import List, Union
from docx import Document
from kag.builder.component.reader import MarkDownReader
from kag.interface import LLMClient
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from kag.interface import ReaderABC
from kag.builder.prompt.outline_prompt import OutlinePrompt
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Input, Output
from kag.common.llm.client import LLMClient
from kag.builder.prompt.outline_prompt import OutlinePrompt
def split_txt(content):
from modelscope.outputs import OutputKeys
@ -30,40 +30,49 @@ def split_txt(content):
p = pipeline(
task=Tasks.document_segmentation,
model='damo/nlp_bert_document-segmentation_chinese-base')
model="damo/nlp_bert_document-segmentation_chinese-base",
)
result = p(documents=content)
result = result[OutputKeys.TEXT]
res = [r for r in result.split('\n\t') if len(r) > 0]
res = [r for r in result.split("\n\t") if len(r) > 0]
return res
class DocxReader(SourceReaderABC):
@ReaderABC.register("docx")
@ReaderABC.register("docx_reader")
class DocxReader(ReaderABC):
"""
A class for reading Docx files, inheriting from SourceReader.
This class is specifically designed to extract text content from Docx files and generate Chunk objects based on the extracted content.
A class for reading Docx files into Chunk objects.
This class inherits from ReaderABC and provides the functionality to process Docx files,
extract their text content, and convert it into a list of Chunk objects.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.split_level = kwargs.get("split_level", 3)
self.split_using_outline = kwargs.get("split_using_outline", True)
self.outline_flag = True
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.prompt = OutlinePrompt(language)
@property
def input_types(self) -> Type[Input]:
return str
def __init__(self, llm: LLMClient = None):
"""
Initializes the DocxReader with an optional LLMClient instance.
@property
def output_types(self) -> Type[Output]:
return Chunk
def outline_chunk(self, chunk: Union[Chunk, List[Chunk]],basename) -> List[Chunk]:
Args:
llm (LLMClient): An optional LLMClient instance used for generating outlines. Defaults to None.
"""
super().__init__()
self.llm = llm
self.prompt = OutlinePrompt(KAG_PROJECT_CONF.language)
def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]:
"""
Generates outlines for the given chunk(s) and separates the content based on these outlines.
Args:
chunk (Union[Chunk, List[Chunk]]): A single Chunk object or a list of Chunk objects.
basename: The base name used for generating chunk IDs and names.
Returns:
List[Chunk]: A list of Chunk objects separated by the generated outlines.
"""
if isinstance(chunk, Chunk):
chunk = [chunk]
outlines = []
@ -71,20 +80,35 @@ class DocxReader(SourceReaderABC):
outline = self.llm.invoke({"input": c.content}, self.prompt)
outlines.extend(outline)
content = "\n".join([c.content for c in chunk])
chunks = self.sep_by_outline(content, outlines,basename)
chunks = self.sep_by_outline(content, outlines, basename)
return chunks
def sep_by_outline(self,content,outlines,basename):
def sep_by_outline(self, content, outlines, basename):
"""
Separates the content based on the provided outlines.
Args:
content (str): The content to be separated.
outlines (List[str]): A list of outlines used to separate the content.
basename: The base name used for generating chunk IDs and names.
Returns:
List[Chunk]: A list of Chunk objects separated by the provided outlines.
"""
position_check = []
for outline in outlines:
start = content.find(outline)
position_check.append((outline,start))
position_check.append((outline, start))
chunks = []
for idx,pc in enumerate(position_check):
for idx, pc in enumerate(position_check):
chunk = Chunk(
id = Chunk.generate_hash_id(f"{basename}#{pc[0]}"),
id=generate_hash_id(f"{basename}#{pc[0]}"),
name=f"{basename}#{pc[0]}",
content=content[pc[1]:position_check[idx+1][1] if idx+1 < len(position_check) else len(position_check)],
content=content[
pc[1] : position_check[idx + 1][1]
if idx + 1 < len(position_check)
else len(content)
],
)
chunks.append(chunk)
return chunks
@ -111,16 +135,25 @@ class DocxReader(SourceReaderABC):
for para in doc.paragraphs:
full_text.append(para.text)
return full_text
def _get_title_from_text(self, text: str) -> str:
text = text.strip()
title = text.split('\n')[0]
text = "\n".join(text.split('\n'))
return title,text
def invoke(self, input: Input, **kwargs) -> List[Output]:
def _get_title_from_text(self, text: str) -> str:
"""
Processes the input Docx file, extracts its text content, and generates a Chunk object.
Extracts the title from the provided text.
Args:
text (str): The text from which to extract the title.
Returns:
tuple: The extracted title and the remaining text.
"""
text = text.strip()
title = text.split("\n")[0]
text = "\n".join(text.split("\n"))
return title, text
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Processes the input Docx file, extracts its text content, and generates Chunk objects.
Args:
input (Input): The file path of the Docx file to be processed.
@ -136,9 +169,9 @@ class DocxReader(SourceReaderABC):
if not input:
raise ValueError("Input cannot be empty")
chunks = []
try:
doc = Document(input)
full_text = self._extract_text_from_docx(doc)
@ -148,32 +181,12 @@ class DocxReader(SourceReaderABC):
basename, _ = os.path.splitext(os.path.basename(input))
for text in full_text:
title,text = self._get_title_from_text(text)
chunk = Chunk(
id=Chunk.generate_hash_id(f"{basename}#{title}"),
name=f"{basename}#{title}",
content=text,
)
chunks.append(chunk)
if len(chunks) < 2:
chunks = self.outline_chunk(chunks,basename)
if len(chunks) < 2:
semantic_res = split_txt(content)
chunks = [Chunk(
id=Chunk.generate_hash_id(input+"#"+r[:10]),
name=basename+"#"+r[:10],
content=r,
) for r in semantic_res]
chunk = Chunk(
id=generate_hash_id(input),
name=basename,
content=content,
**{"documentId": basename, "documentName": basename},
)
chunks.append(chunk)
return chunks
if __name__== "__main__":
reader = DocxReader()
print(reader.output_types)
file_path = os.path.dirname(__file__)
res = reader.invoke(os.path.join(file_path,"../../../../tests/builder/data/test_docx.docx"))
print(res)
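With the refactor, the reader no longer builds its own LLM, so the removed __main__ demo above would now look roughly like this (the file path is hypothetical, and an LLMClient is only needed when outline-based splitting is wanted):

reader = DocxReader()  # or DocxReader(llm=some_llm_client) to enable outline_chunk
chunks = reader.invoke("/path/to/test_docx.docx")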

View File

@ -1,164 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import os
from typing import List, Type, Dict, Union
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder.reader_abc import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.common.llm.client import LLMClient
class JSONReader(SourceReaderABC):
"""
A class for reading JSON files, inheriting from `SourceReader`.
Supports converting JSON data into either a list of dictionaries or a list of Chunk objects.
Args:
output_type (str): Specifies the output type, which can be "Dict" or "Chunk".
**kwargs: Additional keyword arguments passed to the parent class constructor.
"""
def __init__(self, output_type="Chunk", **kwargs):
super().__init__(**kwargs)
if output_type == "Dict":
self.output_types = Dict[str, str]
else:
self.output_types = Chunk
self.id_col = kwargs.get("id_col", "id")
self.name_col = kwargs.get("name_col", "name")
self.content_col = kwargs.get("content_col", "content")
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return self._output_types
@output_types.setter
def output_types(self, output_types):
self._output_types = output_types
@staticmethod
def _read_from_file(file_path: str) -> Union[dict, list]:
"""
Safely reads JSON from a file and returns its content.
Args:
file_path (str): The path to the JSON file.
Returns:
Union[dict, list]: The parsed JSON content.
Raises:
ValueError: If there is an error reading the JSON file.
"""
try:
with open(file_path, "r") as file:
return json.load(file)
except json.JSONDecodeError as e:
raise ValueError(f"Error reading JSON from file: {e}")
except FileNotFoundError as e:
raise ValueError(f"File not found: {e}")
@staticmethod
def _parse_json_string(json_string: str) -> Union[dict, list]:
"""
Parses a JSON string and returns its content.
Args:
json_string (str): The JSON string to parse.
Returns:
Union[dict, list]: The parsed JSON content.
Raises:
ValueError: If there is an error parsing the JSON string.
"""
try:
return json.loads(json_string)
except json.JSONDecodeError as e:
raise ValueError(f"Error parsing JSON string: {e}")
def invoke(self, input: str, **kwargs) -> List[Output]:
"""
Parses the input string data and generates a list of Chunk objects or returns the original data.
This method supports receiving JSON-formatted strings. It extracts specific fields based on provided keyword arguments.
It can read from a file or directly parse a string. If the input data is in the expected format, it generates a list of Chunk objects;
otherwise, it throws a ValueError if the input is not a JSON array or object.
Args:
input (str): The input data, which can be a JSON string or a file path.
**kwargs: Keyword arguments used to specify the field names for ID, name, and content.
Returns:
List[Output]: A list of Chunk objects or the original data.
Raises:
ValueError: If the input data format is incorrect or parsing fails.
"""
id_col = kwargs.get("id_col", "id")
name_col = kwargs.get("name_col", "name")
content_col = kwargs.get("content_col", "content")
self.id_col = id_col
self.name_col = name_col
self.content_col = content_col
try:
if os.path.exists(input):
corpus = self._read_from_file(input)
else:
corpus = self._parse_json_string(input)
except ValueError as e:
raise e
if not isinstance(corpus, (list, dict)):
raise ValueError("Expected input to be a JSON array or object")
if isinstance(corpus, dict):
corpus = [corpus]
if self.output_types == Chunk:
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
for idx, item in enumerate(corpus):
if not isinstance(item, dict):
continue
chunk = Chunk(
id=item.get(self.id_col) or Chunk.generate_hash_id(f"{input}#{idx}"),
name=item.get(self.name_col) or f"{basename}#{idx}",
content=item.get(self.content_col),
)
chunks.append(chunk)
return chunks
else:
return corpus
if __name__ == "__main__":
reader = JSONReader()
json_string = '''[
{
"title": "test_json",
"text": "Test content"
}
]'''
chunks = reader.invoke(json_string,name_column="title",content_col = "text")
res = 1

View File

@ -12,24 +12,37 @@
import os
import bs4.element
import markdown
from bs4 import BeautifulSoup, Tag
from typing import List, Type
import logging
import re
import requests
import pandas as pd
from io import StringIO
from tenacity import stop_after_attempt, retry
from typing import List, Dict
from kag.interface.builder import SourceReaderABC
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from knext.common.base.runnable import Output, Input
from kag.interface import ReaderABC
from kag.builder.model.chunk import Chunk
from kag.interface import LLMClient
from kag.builder.prompt.analyze_table_prompt import AnalyzeTablePrompt
from knext.common.base.runnable import Output, Input
class MarkDownReader(SourceReaderABC):
logger = logging.getLogger(__name__)
class MarkdownNode:
def __init__(self, title: str, level: int, content: str = ""):
self.title = title
self.level = level
self.content = content
self.children: List[MarkdownNode] = []
self.tables: List[Dict] = []  # stores extracted table data
@ReaderABC.register("md")
@ReaderABC.register("md_reader")
class MarkDownReader(ReaderABC):
"""
A class for reading MarkDown files, inheriting from `SourceReader`.
Supports converting MarkDown data into a list of Chunk objects.
@ -41,352 +54,344 @@ class MarkDownReader(SourceReaderABC):
ALL_LEVELS = [f"h{x}" for x in range(1, 7)]
TABLE_CHUCK_FLAG = "<<<table_chuck>>>"
def __init__(self, cut_depth: int = 1, **kwargs):
def __init__(self, cut_depth: int = 3, llm: LLMClient = None, **kwargs):
super().__init__(**kwargs)
self.cut_depth = int(cut_depth)
self.llm_module = kwargs.get("llm_module", None)
self.llm = llm
self.analyze_table_prompt = AnalyzeTablePrompt(language="zh")
self.analyze_img_prompt = AnalyzeTablePrompt(language="zh")
@property
def input_types(self) -> Type[Input]:
def input_types(self):
return str
@property
def output_types(self) -> Type[Output]:
def output_types(self):
return Chunk
def to_text(self, level_tags):
"""
Converts parsed hierarchical tags into text content.
def solve_content(
self, id: str, title: str, content: str, **kwargs
) -> List[Output]:
# Convert Markdown to HTML with additional extensions for lists
html = markdown.markdown(
content, extensions=["tables", "nl2br", "sane_lists", "fenced_code"]
)
soup = BeautifulSoup(html, "html.parser")
Args:
level_tags (list): Parsed tags organized by Markdown heading levels and other tags.
def is_in_code_block(element):
"""Check if an element is inside a code block"""
parent = element.parent
while parent:
if parent.name in ["pre", "code"]:
return True
parent = parent.parent
return False
Returns:
str: Text content derived from the parsed tags.
"""
content = []
for item in level_tags:
if isinstance(item, list):
content.append(self.to_text(item))
else:
header, tag = item
if not isinstance(tag, Tag):
continue
elif tag.name in self.ALL_LEVELS:
content.append(
f"{header}-{tag.text}" if len(header) > 0 else tag.text
)
def process_text_with_links(element):
"""Process text containing links, preserving original markdown format"""
result = []
current_text = ""
for child in element.children:
if isinstance(child, Tag):
if child.name == "a":
# If there's previous text, add it first
if current_text:
result.append(current_text.strip())
current_text = ""
# Rebuild markdown format link
link_text = child.get_text().strip()
href = child.get("href", "")
title = child.get("title", "")
if title:
result.append(f'[{link_text}]({href} "{title}")')
else:
result.append(f"[{link_text}]({href})")
else:
current_text += child.get_text()
else:
content.append(self.tag_to_text(tag))
return "\n".join(content)
current_text += str(child)
def tag_to_text(self, tag: bs4.element.Tag):
"""
Converts an HTML tag to text.
If the tag is a table, output it as markdown and add table markers so that a table Chunk can be built later.
:param tag:
:return:
"""
if tag.name == "table":
try:
html_table = str(tag)
table_df = pd.read_html(html_table)[0]
return f"{self.TABLE_CHUCK_FLAG}{table_df.to_markdown(index=False)}{self.TABLE_CHUCK_FLAG}"
except:
logging.warning("parse table tag to text error", exc_info=True)
return tag.text
if current_text:
result.append(current_text.strip())
@retry(stop=stop_after_attempt(5))
def analyze_table(self, table, analyze_method="human"):
if analyze_method == "llm":
if self.llm_module is None:
logging.info("llm_module is None, cannot use analyze_table")
return table
variables = {
"table": table
}
response = self.llm_module.invoke(
variables = variables,
prompt_op = self.analyze_table_prompt,
with_json_parse=False
)
if response is None or response == "" or response == []:
raise Exception("llm_module return None")
return response
else:
from io import StringIO
import pandas as pd
try:
df = pd.read_html(StringIO(table))[0]
except Exception as e:
logging.warning(f"analyze_table error: {e}")
return table
content = ""
for index, row in df.iterrows():
content+=f"{index+1}行的数据如下:"
for col_name, value in row.items():
content+=f"{col_name}的值为{value}"
content+='\n'
return " ".join(result)
# Initialize root node
root = MarkdownNode("root", 0)
stack = [root]
current_content = []
# Traverse all elements
all_elements = soup.find_all(
[
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"p",
"table",
"ul",
"ol",
"li",
"pre",
"code",
]
)
for element in all_elements:
if element.name.startswith("h") and not is_in_code_block(element):
# Only process headers that are not in code blocks
# Handle title logic
if current_content and stack[-1].title != "root":
stack[-1].content = "\n".join(current_content)
current_content = []
level = int(element.name[1])
title_text = process_text_with_links(element) # Process links in title
new_node = MarkdownNode(title_text, level)
while stack and stack[-1].level >= level:
stack.pop()
if stack:
stack[-1].children.append(new_node)
stack.append(new_node)
elif element.name in ["code"]:
# Preserve code blocks as is
text = element.get_text()
if text:
current_content.append(text)
elif element.name in ["ul", "ol"]:
continue
elif element.name == "li":
text = process_text_with_links(element) # Process links in list items
if text:
if element.find_parent("ol"):
index = len(element.find_previous_siblings("li")) + 1
current_content.append(f"{index}. {text}")
else:
current_content.append(f"* {text}")
elif element.name == "table":
# Process table
table_data = []
headers = []
if element.find("thead"):
for th in element.find("thead").find_all("th"):
headers.append(th.get_text().strip())
if element.find("tbody"):
for row in element.find("tbody").find_all("tr"):
row_data = {}
for i, td in enumerate(row.find_all("td")):
if i < len(headers):
row_data[headers[i]] = td.get_text().strip()
table_data.append(row_data)
# Add table to current node
if stack[-1].title != "root":
stack[-1].tables.append({"headers": headers, "data": table_data})
elif element.name == "p":
text = process_text_with_links(element) # Process links in paragraphs
if text:
if not text.startswith("* ") and not re.match(r"^\d+\. ", text):
current_content.append(text)
# Process content of the last node
if current_content and stack[-1].title != "root":
stack[-1].content = "\n".join(current_content)
outputs = self._convert_to_outputs(root, id)
return outputs
def _convert_to_outputs(
self,
node: MarkdownNode,
id: str,
parent_id: str = None,
parent_titles: List[str] = None,
parent_contents: List[str] = None,
) -> List[Output]:
def convert_table_to_markdown(headers, data):
"""Convert table data to markdown format"""
if not headers or not data:
return ""
# Build header row
header_row = " | ".join(headers)
# Build separator row
separator = " | ".join(["---"] * len(headers))
# Build data rows
data_rows = []
for row in data:
row_values = [str(row.get(header, "")) for header in headers]
data_rows.append(" | ".join(row_values))
# Combine all rows
table_md = f"\n| {header_row} |\n| {separator} |\n"
table_md += "\n".join(f"| {row} |" for row in data_rows)
return table_md + "\n"
def collect_tables(n: MarkdownNode):
"""Collect tables from node and its children"""
tables = []
table_md = []
if n.tables:
for table in n.tables:
tables.append(table)
table_md.append(
convert_table_to_markdown(table["headers"], table["data"])
)
for child in n.children:
child_tables, child_table_md = collect_tables(child)
tables.extend(child_tables)
table_md.extend(child_table_md)
return tables, table_md
def collect_children_content(n: MarkdownNode):
"""Collect content from node and its children"""
content = []
if n.content:
content.append(n.content)
# Add current node's table content
for table in n.tables:
content.append(
convert_table_to_markdown(table["headers"], table["data"])
)
# Process child nodes recursively
for child in n.children:
content.extend(collect_children_content(child))
return content
@retry(stop=stop_after_attempt(5))
def analyze_img(self, img_url):
response = requests.get(img_url)
response.raise_for_status()
image_data = response.content
outputs = []
if parent_titles is None:
parent_titles = []
if parent_contents is None:
parent_contents = []
pass
current_titles = parent_titles + ([node.title] if node.title != "root" else [])
def replace_table(self, content: str):
pattern = r"<table[^>]*>([\s\S]*?)<\/table>"
for match in re.finditer(pattern, content):
table = match.group(0)
table = self.analyze_table(table)
content = content.replace(match.group(1), table)
return content
# If current node level equals target level, create output
if node.level >= self.cut_depth:
full_title = " / ".join(current_titles)
def replace_img(self, content: str):
pattern = r"<img[^>]*src=[\"\']([^\"\']*)[\"\']"
for match in re.finditer(pattern, content):
img_url = match.group(1)
img_msg = self.analyze_img(img_url)
content = content.replace(match.group(0), img_msg)
return content
# Merge content: parent content + current content
all_content = parent_contents + ([node.content] if node.content else [])
def extract_table(self, level_tags, header=""):
"""
Extracts tables from the parsed hierarchical tags along with their headers.
Args:
level_tags (list): Parsed tags organized by Markdown heading levels and other tags.
header (str): Current header text being processed.
Returns:
list: A list of tuples, each containing the table's header, context text, and the table tag.
"""
tables = []
for idx, item in enumerate(level_tags):
if isinstance(item, list):
tables += self.extract_table(item, header)
else:
tag = item[1]
if not isinstance(tag, Tag):
continue
if tag.name in self.ALL_LEVELS:
header = f"{header}-{tag.text}" if len(header) > 0 else tag.text
if tag.name == "table":
if idx - 1 >= 0:
context = level_tags[idx - 1]
if isinstance(context, tuple):
tables.append((header, context[1].text, tag))
else:
tables.append((header, "", tag))
return tables
def parse_level_tags(
self,
level_tags: list,
level: str,
parent_header: str = "",
cur_header: str = "",
):
"""
Recursively parses level tags to organize them into a structured format.
Args:
level_tags (list): A list of tags to be parsed.
level (str): The current level being processed.
parent_header (str): The header of the parent tag.
cur_header (str): The header of the current tag.
Returns:
list: A structured representation of the parsed tags.
"""
if len(level_tags) == 0:
return []
output = []
prefix_tags = []
while len(level_tags) > 0:
tag = level_tags[0]
if tag.name in self.ALL_LEVELS:
break
else:
prefix_tags.append((parent_header, level_tags.pop(0)))
if len(prefix_tags) > 0:
output.append(prefix_tags)
cur = []
while len(level_tags) > 0:
tag = level_tags[0]
if tag.name not in self.ALL_LEVELS:
cur.append((parent_header, level_tags.pop(0)))
else:
if tag.name > level:
cur += self.parse_level_tags(
level_tags,
tag.name,
f"{parent_header}-{cur_header}"
if len(parent_header) > 0
else cur_header,
tag.name,
)
elif tag.name == level:
if len(cur) > 0:
output.append(cur)
cur = [(parent_header, level_tags.pop(0))]
cur_header = tag.text
else:
if len(cur) > 0:
output.append(cur)
return output
if len(cur) > 0:
output.append(cur)
return output
def cut(self, level_tags, cur_level, final_level):
"""
Cuts the provided level tags into chunks based on the specified levels.
Args:
level_tags (list): A list of tags to be cut.
cur_level (int): The current level in the hierarchy.
final_level (int): The final level to which the tags should be cut.
Returns:
list: A list of cut chunks.
"""
output = []
if cur_level == final_level:
cur_prefix = []
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, tuple):
cur_prefix.append(self.to_text([sublevel_tags,]))
else:
break
cur_prefix = "\n".join(cur_prefix)
if len(cur_prefix) > 0:
output.append(cur_prefix)
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, list):
output.append(cur_prefix + "\n" + self.to_text(sublevel_tags))
return output
else:
cur_prefix = []
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, tuple):
cur_prefix.append(sublevel_tags[1].text)
else:
break
cur_prefix = "\n".join(cur_prefix)
if len(cur_prefix) > 0:
output.append(cur_prefix)
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, list):
output += self.cut(sublevel_tags, cur_level + 1, final_level)
return output
def solve_content(self, id: str, title: str, content: str, **kwargs) -> List[Output]:
"""
Converts Markdown content into structured chunks.
Args:
id (str): An identifier for the content.
title (str): The title of the content.
content (str): The Markdown formatted content to be processed.
Returns:
List[Output]: A list of processed content chunks.
"""
html_content = markdown.markdown(
content, extensions=["markdown.extensions.tables"]
)
# html_content = self.replace_table(html_content)
soup = BeautifulSoup(html_content, "html.parser")
if soup is None:
raise ValueError("The MarkDown file appears to be empty or unreadable.")
top_level = None
for level in self.ALL_LEVELS:
tmp = soup.find_all(level)
if len(tmp) > 0:
top_level = level
break
if top_level is None:
chunk = Chunk(
id=Chunk.generate_hash_id(str(id)),
name=title,
content=soup.text,
ref=kwargs.get("ref", ""),
)
return [chunk]
tags = [tag for tag in soup.children if isinstance(tag, Tag)]
level_tags = self.parse_level_tags(tags, top_level)
cutted = self.cut(level_tags, 0, self.cut_depth)
chunks = []
for idx, content in enumerate(cutted):
chunk = None
if self.TABLE_CHUCK_FLAG in content:
chunk = self.get_table_chuck(content, title, id, idx)
chunk.ref = kwargs.get("ref", "")
else:
chunk = Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
content=content,
ref=kwargs.get("ref", ""),
# Add current node's table content
for table in node.tables:
all_content.append(
convert_table_to_markdown(table["headers"], table["data"])
)
chunks.append(chunk)
return chunks
def get_table_chuck(self, table_chunk_str: str, title: str, id: str, idx: int) -> Chunk:
"""
convert table chunk
:param table_chunk_str:
:return:
"""
table_chunk_str = table_chunk_str.replace("\\N", "")
pattern = f"{self.TABLE_CHUCK_FLAG}(.*){self.TABLE_CHUCK_FLAG}"
matches = re.findall(pattern, table_chunk_str, re.DOTALL)
if not matches or len(matches) <= 0:
# No table information found; fall back to treating it as a text chunk
return Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
content=table_chunk_str,
# Add all child node content (including tables)
for child in node.children:
child_content = collect_children_content(child)
all_content.extend(child_content)
current_output = Chunk(
id=f"{id}_{len(outputs)}",
parent_id=parent_id,
name=full_title,
content="\n".join(filter(None, all_content)),
)
table_markdown_str = matches[0]
html_table_str = markdown.markdown(table_markdown_str, extensions=["markdown.extensions.tables"])
try:
df = pd.read_html(html_table_str)[0]
except Exception as e:
logging.warning(f"get_table_chuck error: {e}")
df = pd.DataFrame()
# Confirmed as a table chunk; strip the TABLE_CHUCK_FLAG markers from the content
replaced_table_text = re.sub(pattern, f'\n{table_markdown_str}\n', table_chunk_str, flags=re.DOTALL)
return Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
content=replaced_table_text,
type=ChunkTypeEnum.Table,
csv_data=df.to_csv(index=False),
)
# Collect table data and convert to markdown format
all_tables = []
table_contents = []
if node.tables:
for table in node.tables:
all_tables.append(table)
table_contents.append(
convert_table_to_markdown(table["headers"], table["data"])
)
def invoke(self, input: Input, **kwargs) -> List[Output]:
for child in node.children:
child_tables, child_table_md = collect_tables(child)
all_tables.extend(child_tables)
table_contents.extend(child_table_md)
if all_tables:
current_output.metadata = {"tables": all_tables}
current_output.table = "\n".join(
table_contents
) # Save all tables in markdown format
outputs.append(current_output)
# If current node level is less than target level, continue traversing
elif node.level < self.cut_depth:
# Check if any subtree contains target level nodes
has_target_level = False
current_contents = parent_contents + (
[node.content] if node.content else []
)
# Add current node's tables to content
for table in node.tables:
current_contents.append(
convert_table_to_markdown(table["headers"], table["data"])
)
for child in node.children:
child_outputs = self._convert_to_outputs(
child, id, parent_id, current_titles, current_contents
)
if child_outputs:
has_target_level = True
outputs.extend(child_outputs)
# If no target level nodes found and current node is not root, output current node
if not has_target_level and node.title != "root":
full_title = " / ".join(current_titles)
all_content = current_contents
for child in node.children:
child_content = collect_children_content(child)
all_content.extend(child_content)
current_output = Chunk(
id=f"{id}_{len(outputs)}",
parent_id=parent_id,
name=full_title,
content="\n".join(filter(None, all_content)),
)
# Collect table data and convert to markdown format
all_tables = []
table_contents = []
if node.tables:
for table in node.tables:
all_tables.append(table)
table_contents.append(
convert_table_to_markdown(table["headers"], table["data"])
)
for child in node.children:
child_tables, child_table_md = collect_tables(child)
all_tables.extend(child_tables)
table_contents.extend(child_table_md)
if all_tables:
current_output.metadata = {"tables": all_tables}
current_output.table = "\n".join(
table_contents
) # Save all tables in markdown format
outputs.append(current_output)
return outputs
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Processes a Markdown file and returns its content as structured chunks.
@ -411,4 +416,55 @@ class MarkDownReader(SourceReaderABC):
basename, _ = os.path.splitext(os.path.basename(file_path))
chunks = self.solve_content(input, basename, content)
length_500_list = []
length_1000_list = []
length_5000_list = []
length_small_list = []
for chunk in chunks:
if chunk.content is not None:
if len(chunk.content) > 5000:
length_5000_list.append(chunk)
elif len(chunk.content) > 1000:
length_1000_list.append(chunk)
elif len(chunk.content) > 500:
length_500_list.append(chunk)
elif len(chunk.content) <= 500:
length_small_list.append(chunk)
return chunks
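For reference, a minimal usage sketch of the Markdown reader above; the file path is a placeholder, and it assumes the ReaderABC base exposes invoke() as the public wrapper around _invoke().
# Hypothetical usage sketch; path and the public invoke() wrapper are assumptions.
from kag.builder.component.reader.markdown_reader import MarkDownReader

reader = MarkDownReader()                  # cut_depth controls the heading level used for cutting
chunks = reader.invoke("./docs/guide.md")  # assumed to delegate to _invoke()
for c in chunks:
    print(c.name, len(c.content))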
@ReaderABC.register("yuque")
@ReaderABC.register("yuque_reader")
class YuequeReader(MarkDownReader):
"""
A class for parsing Yuque documents into Chunk objects.
This class inherits from MarkDownReader and provides the functionality to process Yuque documents,
extract their content, and convert it into a list of Chunk objects.
"""
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Processes the input Yuque document and converts it into a list of Chunk objects.
Args:
input (Input): The input string containing the Yuque token and URL.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list of Chunk objects representing the parsed content.
Raises:
HTTPError: If the request to the Yueque URL fails.
"""
token, url = input.split("@", 1)
headers = {"X-Auth-Token": token}
response = requests.get(url, headers=headers)
response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx)
data = response.json()["data"]
id = data.get("id", "")
title = data.get("title", "")
content = data.get("body", "")
chunks = self.solve_content(id, title, content)
return chunks
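A hedged sketch of the "token@url" input contract described above; the token and document URL are placeholders, and the module path is an assumption.
# Hypothetical usage: the input is a "token@url" string, split on the first "@".
from kag.builder.component.reader.markdown_reader import YuequeReader  # module path assumed

doc = "MY_YUQUE_TOKEN@https://www.yuque.com/api/v2/repos/my_group/my_repo/docs/my_slug"
chunks = YuequeReader().invoke(doc)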

View File

@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List
from kag.interface import ReaderABC
from knext.common.base.runnable import Input, Output
from kag.builder.component.reader.txt_reader import TXTReader
from kag.builder.component.reader.pdf_reader import PDFReader
from kag.builder.component.reader.docx_reader import DocxReader
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.component.reader.dict_reader import DictReader
@ReaderABC.register("mix", as_default=True)
@ReaderABC.register("mix_reader")
class MixReader(ReaderABC):
"""
A reader class that can handle multiple types of inputs by delegating to specific readers.
This class initializes with a mapping of file types to their respective readers.
It provides a method to invoke the appropriate reader based on the input type.
"""
def __init__(
self,
txt_reader: TXTReader = None,
pdf_reader: PDFReader = None,
docx_reader: DocxReader = None,
md_reader: MarkDownReader = None,
dict_reader: DictReader = None,
):
"""
Initializes the MixReader with a mapping of file types to their respective readers.
Args:
txt_reader (TXTReader, optional): Reader for .txt files. Defaults to None.
pdf_reader (PDFReader, optional): Reader for .pdf files. Defaults to None.
docx_reader (DocxReader, optional): Reader for .docx files. Defaults to None.
md_reader (MarkDownReader, optional): Reader for .md files. Defaults to None.
dict_reader (DictReader, optional): Reader for dictionary inputs. Defaults to None.
"""
super().__init__()
self.parse_map = {
"txt": txt_reader,
"pdf": pdf_reader,
"docx": docx_reader,
"md": md_reader,
"dict": dict_reader,
}
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the appropriate reader based on the input type.
Args:
input (Input): The input to be parsed. This can be a file path or a dictionary.
**kwargs: Additional keyword arguments to be passed to the reader.
Returns:
List[Output]: A list of parsed outputs.
Raises:
ValueError: If the input is empty.
FileNotFoundError: If the input file does not exist.
NotImplementedError: If the file suffix is not supported.
KeyError: If the reader for the given file type is not correctly configured.
"""
if not input:
raise ValueError("Input cannot be empty")
if isinstance(input, dict):
reader_type = "dict"
else:
if not os.path.exists(input):
raise FileNotFoundError(f"File {input} not found.")
file_suffix = input.split(".")[-1]
if file_suffix not in self.parse_map:
raise NotImplementedError(
f"File suffix {file_suffix} not supported yet."
)
reader_type = file_suffix
reader = self.parse_map[reader_type]
if reader is None:
raise KeyError(f"{reader_type} reader not correctly configured.")
return reader._invoke(input, **kwargs)
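A short sketch of the suffix-based dispatch after the fixes above; the file names are placeholders, the mix_reader module path is an assumption, and the sub-readers would normally be built from the YAML config.
# Hypothetical dispatch sketch: inputs are routed to sub-readers by file suffix.
from kag.builder.component.reader.txt_reader import TXTReader
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.component.reader.mix_reader import MixReader  # module path assumed

mix = MixReader(txt_reader=TXTReader(), md_reader=MarkDownReader())
md_chunks = mix.invoke("notes/design.md")   # routed to MarkDownReader via the "md" suffix
txt_chunks = mix.invoke("notes/plain.txt")  # routed to TXTReader via the "txt" suffix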

View File

@ -12,28 +12,26 @@
import os
import re
from typing import List, Sequence, Type, Union
from typing import List, Sequence, Union
import pdfminer.layout # noqa
from langchain_community.document_loaders import PyPDFLoader
import pdfminer.layout
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.builder.prompt.outline_prompt import OutlinePrompt
from kag.interface import ReaderABC
from kag.builder.prompt.outline_prompt import OutlinePrompt
from kag.interface import LLMClient
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Output
from pdfminer.high_level import extract_text
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextContainer, LTPage
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
from pdfminer.layout import LAParams,LTTextBox
from pdfminer.pdfpage import PDFPage
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
from pdfminer.converter import PDFPageAggregator
from pdfminer.pdfpage import PDFTextExtractionNotAllowed
import pdfminer
import pdfminer # noqa
import PyPDF2
import logging
@ -41,34 +39,207 @@ import logging
logger = logging.getLogger(__name__)
class PDFReader(SourceReaderABC):
@ReaderABC.register("pdf")
@ReaderABC.register("pdf_reader")
class PDFReader(ReaderABC):
"""
A PDF reader class that inherits from SourceReader.
A class for reading PDF files into a list of text chunks, inheriting from `ReaderABC`.
Attributes:
if_split (bool): Whether to split the content by pages. Default is False.
use_pypdf (bool): Whether to use PyPDF2 for processing PDF files. Default is True.
This class is responsible for parsing PDF files and converting them into a list of Chunk objects.
It inherits from `ReaderABC` and overrides the necessary methods to handle PDF-specific operations.
"""
def __init__(self, **kwargs):
def __init__(
self,
cut_depth: int = 3,
outline_flag: bool = True,
is_ocr: bool = False,
llm: LLMClient = None,
**kwargs,
):
super().__init__(**kwargs)
self.split_level = kwargs.get("split_level", 3)
self.split_using_outline = kwargs.get("split_using_outline", True)
self.outline_flag = True
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.cut_depth = cut_depth
self.outline_flag = outline_flag
self.is_ocr = is_ocr
self.llm = llm
language = KAG_PROJECT_CONF.language
self.prompt = OutlinePrompt(language)
@property
def input_types(self) -> Type[Input]:
def input_types(self):
return str
@property
def output_types(self) -> Type[Output]:
def output_types(self):
return Chunk
def outline_chunk(self, chunk: Union[Chunk, List[Chunk]],basename) -> List[Chunk]:
def _get_full_outlines(self):
outlines = self.pdf_reader.outline
level_outlines = []
def _extract_outline_page_numbers(outlines, level=0):
for outline in outlines:
if isinstance(outline, list):
_extract_outline_page_numbers(outline, level + 1)
else:
title = outline.title
page_number = self.pdf_reader.get_destination_page_number(outline)
level_outlines.append((title, level, page_number, 0))
_extract_outline_page_numbers(outlines)
for idx, outline in enumerate(level_outlines):
level_outlines[idx] = (
outline[0],
outline[1],
outline[2],
level_outlines[idx + 1][2] if idx + 1 < len(level_outlines) else -1,
)
return level_outlines
def extract_content_from_outline(
self, page_contents, level_outlines
) -> List[Chunk]:
total_content = "".join(page_contents)
def get_content_start(outline, page_contents):
page_start = outline[2]
page_end = outline[3]
previous_pages_length = sum(
len(content) for content in page_contents[:page_start]
)
find_content = "".join(
page_contents[page_start : page_end + 1 if page_end != -1 else None]
)
# Normalize special characters in the title
def normalize_text(text):
# Convert the em-dash "—" to the Chinese numeral "一"
text = text.replace("—", "一")
# Additional Chinese/English punctuation can be unified here
text = re.sub(r"【", "[", text)
text = re.sub(r"】", "]", text)
text = re.sub(r"（", "(", text)
text = re.sub(r"）", ")", text)
return text
outline = (normalize_text(outline[0]), outline[1], outline[2], outline[3])
def fuzzy_search(pattern, text, threshold=0.90):
from difflib import SequenceMatcher
pattern_len = len(pattern)
for i in range(len(text) - pattern_len + 1):
substring = text[i : i + pattern_len]
similarity = SequenceMatcher(None, pattern, substring).ratio()
if similarity >= threshold:
return i
return -1
# First try fuzzy matching with the original title
title_with_spaces = outline[0].strip()
fuzzy_match_pos = fuzzy_search(title_with_spaces, find_content)
if fuzzy_match_pos != -1:
return previous_pages_length + fuzzy_match_pos
# If not found, try again with all spaces removed from the title
title_no_spaces = title_with_spaces.replace(" ", "")
find_content_no_spaces = find_content.replace(" ", "")
fuzzy_match_pos = fuzzy_search(title_no_spaces, find_content_no_spaces)
if fuzzy_match_pos != -1:
# Map the match position back to the original (space-containing) text
original_pos = 0
no_spaces_pos = 0
while no_spaces_pos < fuzzy_match_pos:
if find_content[original_pos] != " ":
no_spaces_pos += 1
original_pos += 1
return previous_pages_length + original_pos
# Fuzzy match within an extended page range
extended_content = "".join(
page_contents[
max(0, page_start - 1) : page_end if page_end != -1 else None
]
)
fuzzy_match_pos = fuzzy_search(title_with_spaces, extended_content)
if fuzzy_match_pos != -1:
extended_previous_length = sum(
len(content) for content in page_contents[: max(0, page_start - 1)]
)
return extended_previous_length + fuzzy_match_pos
# Finally, try the extended content with spaces removed
extended_content_no_spaces = extended_content.replace(" ", "")
fuzzy_match_pos = fuzzy_search(title_no_spaces, extended_content_no_spaces)
if fuzzy_match_pos != -1:
original_pos = 0
no_spaces_pos = 0
while no_spaces_pos < fuzzy_match_pos:
if extended_content[original_pos] != " ":
no_spaces_pos += 1
original_pos += 1
extended_previous_length = sum(
len(content) for content in page_contents[: max(0, page_start - 1)]
)
return extended_previous_length + original_pos
return -1
final_content = []
for idx, outline in enumerate(level_outlines):
start = get_content_start(outline, page_contents)
next_start = (
get_content_start(level_outlines[idx + 1], page_contents)
if idx + 1 < len(level_outlines)
else -1
)
if start >= 0 and next_start >= 0:
content = total_content[start:next_start]
final_content.append(
(outline[0], outline[1], start, next_start, content)
)
elif start >= 0 and next_start < 0 and idx + 1 == len(level_outlines):
content = total_content[start:]
final_content.append((outline[0], outline[1], start, -1, content))
return final_content
def convert_finel_content_to_chunks(self, final_content):
def create_chunk(title, content, basename):
return Chunk(
id=generate_hash_id(f"{basename}#{title}"),
name=f"{basename}#{title}",
content=content,
sub_chunks=[],
)
level_map = {}
chunks = []
for title, level, start, end, content in final_content:
chunk = create_chunk(
title, content, os.path.splitext(os.path.basename(self.fd.name))[0]
)
chunks.append(chunk)
if level == 0:
level_map[0] = chunk
else:
parent_level = level - 1
while parent_level >= 0:
if parent_level in level_map:
level_map[parent_level].sub_chunks.append(chunk)
break
parent_level -= 1
level_map[level] = chunk
return chunks
def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]:
if isinstance(chunk, Chunk):
chunk = [chunk]
outlines = []
@ -76,26 +247,30 @@ class PDFReader(SourceReaderABC):
outline = self.llm.invoke({"input": c.content}, self.prompt)
outlines.extend(outline)
content = "\n".join([c.content for c in chunk])
chunks = self.sep_by_outline(content, outlines,basename)
chunks = self.sep_by_outline(content, outlines, basename)
return chunks
def sep_by_outline(self,content,outlines,basename):
def sep_by_outline(self, content, outlines, basename):
position_check = []
for outline in outlines:
start = content.find(outline)
position_check.append((outline,start))
position_check.append((outline, start))
chunks = []
for idx,pc in enumerate(position_check):
for idx, pc in enumerate(position_check):
chunk = Chunk(
id = Chunk.generate_hash_id(f"{basename}#{pc[0]}"),
id=generate_hash_id(f"{basename}#{pc[0]}"),
name=f"{basename}#{pc[0]}",
content=content[pc[1]:position_check[idx+1][1] if idx+1 < len(position_check) else len(position_check)],
content=content[
pc[1] : (
position_check[idx + 1][1]
if idx + 1 < len(position_check)
else len(position_check)
)
],
)
chunks.append(chunk)
return chunks
@staticmethod
def _process_single_page(
page: str,
@ -149,7 +324,7 @@ class PDFReader(SourceReaderABC):
text += element.get_text()
return text
def invoke(self, input: str, **kwargs) -> Sequence[Output]:
def _invoke(self, input: str, **kwargs) -> Sequence[Output]:
"""
Processes a PDF file, splitting or extracting content based on configuration.
@ -170,85 +345,140 @@ class PDFReader(SourceReaderABC):
if not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.")
self.fd = open(input, "rb")
self.parser = PDFParser(self.fd)
self.document = PDFDocument(self.parser)
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
# get outline
self.fd = None
try:
outlines = self.document.get_outlines()
except Exception as e:
logger.warning(f"loading PDF file: {e}")
self.outline_flag = False
if not self.outline_flag:
self.fd = open(input, "rb")
self.pdf_reader = PyPDF2.PdfReader(self.fd)
self.level_outlines = self._get_full_outlines()
self.parser = PDFParser(self.fd)
self.document = PDFDocument(self.parser)
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
with open(input, "rb") as file:
for idx, page_layout in enumerate(extract_pages(file)):
content = ""
for element in page_layout:
if hasattr(element, "get_text"):
content = content + element.get_text()
# get outline
try:
outlines = self.document.get_outlines()
except Exception as e:
logger.warning(f"loading PDF file: {e}")
self.outline_flag = False
if not self.outline_flag:
with open(input, "rb") as file:
for idx, page_layout in enumerate(extract_pages(file)):
content = ""
for element in page_layout:
if hasattr(element, "get_text"):
content = content + element.get_text()
chunk = Chunk(
id=generate_hash_id(f"{basename}#{idx}"),
name=f"{basename}#{idx}",
content=content,
)
chunks.append(chunk)
# try:
# outline_chunks = self.outline_chunk(chunks, basename)
# except Exception as e:
# raise RuntimeError(f"Error loading PDF file: {e}")
# if len(outline_chunks) > 0:
# chunks = outline_chunks
elif True:
split_words = []
page_contents = []
with open(input, "rb") as file:
for idx, page_layout in enumerate(extract_pages(file)):
content = ""
for element in page_layout:
if hasattr(element, "get_text"):
content = content + element.get_text()
content = content.replace("\n", "")
page_contents.append(content)
# Use a regex to strip all whitespace characters (spaces, tabs, newlines, etc.)
page_contents = [
re.sub(r"\s+", "", content) for content in page_contents
]
page_contents = [
re.sub(r"[\s\u200b\u200c\u200d\ufeff]+", "", content)
for content in page_contents
]
page_contents = ["".join(content.split()) for content in page_contents]
final_content = self.extract_content_from_outline(
page_contents, self.level_outlines
)
chunks = self.convert_finel_content_to_chunks(final_content)
else:
for item in outlines:
level, title, dest, a, se = item
split_words.append(title.strip().replace(" ", ""))
# save the outline position in content
try:
text = extract_text(input)
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
cleaned_pages = [
self._process_single_page(x, "", False, False) for x in text
]
sentences = []
for cleaned_page in cleaned_pages:
sentences += cleaned_page
content = "".join(sentences)
positions = [(input, 0)]
for split_word in split_words:
pattern = re.compile(split_word)
start = 0
for i, match in enumerate(re.finditer(pattern, content)):
if i <= 1:
start, end = match.span()
if start > 0:
positions.append((split_word, start))
for idx, position in enumerate(positions):
chunk = Chunk(
id=Chunk.generate_hash_id(f"{basename}#{idx}"),
name=f"{basename}#{idx}",
content=content,
id=generate_hash_id(f"{basename}#{position[0]}"),
name=f"{basename}#{position[0]}",
content=content[
position[1] : (
positions[idx + 1][1]
if idx + 1 < len(positions)
else None
)
],
)
chunks.append(chunk)
try:
outline_chunks = self.outline_chunk(chunks, basename)
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
if len(outline_chunks) > 0:
chunks = outline_chunks
else:
split_words = []
for item in outlines:
level, title, dest, a, se = item
split_words.append(title.strip().replace(" ",""))
# save the outline position in content
try:
text = extract_text(input)
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
# # Save intermediate results to a file
# import pickle
cleaned_pages = [
self._process_single_page(x, "", False, False) for x in text
]
sentences = []
for cleaned_page in cleaned_pages:
sentences += cleaned_page
# with open("debug_data.pkl", "wb") as f:
# pickle.dump(
# {"page_contents": page_contents, "level_outlines": self.level_outlines},
# f,
# )
content = "".join(sentences)
positions = [(input,0)]
for split_word in split_words:
pattern = re.compile(split_word)
for i,match in enumerate(re.finditer(pattern, content)):
if i == 1:
start, end = match.span()
positions.append((split_word,start))
for idx,position in enumerate(positions):
chunk = Chunk(
id = Chunk.generate_hash_id(f"{basename}#{position[0]}"),
name=f"{basename}#{position[0]}",
content=content[position[1]:positions[idx+1][1] if idx+1 < len(positions) else None],
)
chunks.append(chunk)
return chunks
return chunks
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
finally:
if self.fd:
self.fd.close()
if __name__ == '__main__':
reader = PDFReader(split_using_outline=True)
pdf_path = os.path.join(os.path.dirname(__file__),"../../../../tests/builder/data/aiwen.pdf")
chunk = reader.invoke(pdf_path)
print(chunk)
if __name__ == "__main__":
pdf_reader = PDFReader()
pdf_path = os.path.join(
os.path.dirname(__file__), "../../../../tests/builder/data/aiwen.pdf"
)
pdf_path = "/Users/zhangxinhong.zxh/Downloads/labor-law-v5.pdf"
# pdf_path = "/Users/zhangxinhong.zxh/Downloads/toaz.info-5dsm-5-pr_56e68a629dc4fe62699960dd5afbe362.pdf"
chunk = pdf_reader.invoke(pdf_path)
a = 1

View File

@ -11,29 +11,27 @@
# or implied.
import os
from typing import List, Type
from typing import List
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from kag.interface import ReaderABC
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Input, Output
class TXTReader(SourceReaderABC):
@ReaderABC.register("txt")
@ReaderABC.register("txt_reader")
class TXTReader(ReaderABC):
"""
A PDF reader class that inherits from SourceReader.
A class for parsing text files or text content into Chunk objects.
This class inherits from ReaderABC and provides the functionality to read text content,
whether it is from a file or directly provided as a string, and convert it into a list of Chunk objects.
"""
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return Chunk
def invoke(self, input: Input, **kwargs) -> List[Output]:
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
The main method for processing text reading. This method reads the content of the input (which can be a file path or text content) and converts it into a Chunk object.
The main method for processing text reading. This method reads the content of the input (which can be a file path or text content) and converts it into chunks.
Args:
input (Input): The input string, which can be the path to a text file or direct text content.
@ -51,7 +49,7 @@ class TXTReader(SourceReaderABC):
try:
if os.path.exists(input):
with open(input, "r", encoding='utf-8') as f:
with open(input, "r", encoding="utf-8") as f:
content = f.read()
else:
content = input
@ -60,7 +58,7 @@ class TXTReader(SourceReaderABC):
basename, _ = os.path.splitext(os.path.basename(input))
chunk = Chunk(
id=Chunk.generate_hash_id(input),
id=generate_hash_id(input),
name=basename,
content=content,
)
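A quick sketch of the dual input contract described in the docstring above: the argument may be a file path or raw text. The invoke() wrapper is assumed to delegate to _invoke().
# Hypothetical usage: raw text passed directly is wrapped into a single Chunk.
from kag.builder.component.reader.txt_reader import TXTReader

chunks = TXTReader().invoke("Plain text content can be passed directly, not only a file path.")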

View File

@ -1,67 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import requests
from typing import Type, List
from kag.builder.component.reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.common.llm.client import LLMClient
class YuqueReader(SourceReaderABC):
def __init__(self, token: str, **kwargs):
super().__init__(**kwargs)
self.token = token
self.markdown_reader = MarkDownReader(**kwargs)
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
@staticmethod
def get_yuque_api_data(token, url):
headers = {"X-Auth-Token": token}
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx)
return response.json()["data"] # Assuming the API returns JSON data
except requests.exceptions.HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except requests.exceptions.RequestException as err:
print(f"Error occurred: {err}")
except Exception as err:
print(f"An error occurred: {err}")
def invoke(self, input: str, **kwargs) -> List[Output]:
if not input:
raise ValueError("Input cannot be empty")
url: str = input
data = self.get_yuque_api_data(self.token, url)
id = data.get("id", "")
title = data.get("title", "")
content = data.get("body", "")
chunks = self.markdown_reader.solve_content(id, title, content)
return chunks

View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Dict, List
import pandas as pd
from kag.interface import ScannerABC
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Input, Output
@ScannerABC.register("csv")
@ScannerABC.register("csv_scanner")
class CSVScanner(ScannerABC):
def __init__(
self,
header: bool = True,
col_names: List[str] = None,
col_ids: List[int] = None,
rank: int = 0,
world_size: int = 1,
):
super().__init__(rank=rank, world_size=world_size)
self.header = header
self.col_names = col_names
self.col_ids = col_ids
@property
def input_types(self) -> Input:
return str
@property
def output_types(self) -> Output:
return Dict
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data from a CSV file and converts it into a list of dictionaries.
Args:
input (Input): The input file path to the CSV file.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of dictionaries containing the processed data.
"""
input = self.download_data(input)
if self.header:
data = pd.read_csv(input, dtype=str)
else:
data = pd.read_csv(input, dtype=str, header=None)
col_keys = self.col_names if self.col_names else self.col_ids
if col_keys is None:
return data.to_dict(orient="records")
contents = []
for _, row in data.iterrows():
for k, v in row.items():
if k in col_keys:
v = str(v)
name = v[:5] + "..." + v[-5:]
contents.append(
{"id": generate_hash_id(v), "name": name, "content": v}
)
return contents
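A hedged usage sketch of the column-selection behavior above; the module path, file name, and column name are assumptions. Without col_names/col_ids the scanner returns whole rows as dicts instead.
# Hypothetical usage: emit one {"id", "name", "content"} record per cell of the "content" column.
from kag.builder.component.scanner.csv_scanner import CSVScanner  # module path assumed

scanner = CSVScanner(header=True, col_names=["content"])
records = scanner.load_data("corpus.csv")
# each record: {"id": <hash of cell>, "name": "<first5>...<last5>", "content": <cell text>}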

View File

@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import os
from typing import List, Type, Dict
from kag.interface import ScannerABC
from knext.common.base.runnable import Input, Output
@ScannerABC.register("hotpotqa")
@ScannerABC.register("hotpotqa_dataset_scanner")
class HotpotqaCorpusScanner(ScannerABC):
"""
A class for reading the HotpotQA dataset and converting it into a list of dictionaries, inheriting from `ScannerABC`.
This class is responsible for reading the HotpotQA corpus and converting it into a list of dictionaries.
It inherits from `ScannerABC` and overrides the necessary methods to handle HotpotQA-specific operations.
"""
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return Dict
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data from a HotpotQA corpus file or JSON string and returns it as a list of dictionaries.
This method reads HotpotQA corpus data from a file or parses a JSON string and returns it as a list of dictionaries.
If the input is a file path, it reads the file; if the input is a JSON string, it parses the string.
Args:
input (Input): The HotpotQA corpus file path or JSON string to load.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of dictionaries, where each dictionary represents a HotpotQA item.
"""
if os.path.exists(str(input)):
with open(input, "r") as f:
corpus = json.load(f)
else:
corpus = json.loads(input)
data = []
for item_key, item_value in corpus.items():
data.append(
{"id": item_key, "name": item_key, "content": "\n".join(item_value)}
)
return data
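To make the expected corpus layout concrete, a minimal sketch with an inline JSON string (the module path is an assumption): the corpus is a JSON object mapping each title to its list of sentences, which are joined with newlines.
# Hypothetical input shape for the HotpotQA scanner.
import json
from kag.builder.component.scanner.dataset_scanner import HotpotqaCorpusScanner  # module path assumed

corpus = {"Paris": ["Paris is the capital of France.", "It lies on the Seine."]}
docs = HotpotqaCorpusScanner().load_data(json.dumps(corpus))
# -> [{"id": "Paris", "name": "Paris",
#      "content": "Paris is the capital of France.\nIt lies on the Seine."}]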
@ScannerABC.register("musique")
@ScannerABC.register("2wiki")
@ScannerABC.register("musique_dataset_scanner")
@ScannerABC.register("2wiki_dataset_scanner")
class MusiqueCorpusScanner(ScannerABC):
"""
A class for reading the Musique/2Wiki dataset and converting it into a list of dictionaries, inheriting from `ScannerABC`.
This class is responsible for reading the Musique/2Wiki corpus and converting it into a list of dictionaries.
It inherits from `ScannerABC` and overrides the necessary methods to handle Musique/2Wiki-specific operations.
"""
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Dict
def get_basename(self, file_name: str):
base, _ = os.path.splitext(os.path.basename(file_name))
return base
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data from a Musique/2Wiki corpus file or JSON string and returns it as a list of dictionaries.
This method reads Musique/2Wiki corpus data from a file or parses a JSON string and returns it as a list of dictionaries.
If the input is a file path, it reads the file; if the input is a JSON string, it parses the string.
Args:
input (Input): The Musique/2Wiki corpus file path or JSON string to load.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of dictionaries, where each dictionary represents a Musique/2Wiki item.
"""
if os.path.exists(input):
with open(input, "r") as f:
corpus = json.load(f)
else:
corpus = json.loads(input)
data = []
for idx, item in enumerate(corpus):
title = item["title"]
content = item["text"]
data.append(
{
"id": f"{title}#{idx}",
"name": title,
"content": content,
}
)
return data
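And the corresponding sketch for the Musique/2Wiki format, which is a JSON array of objects with "title" and "text" fields (module path assumed):
# Hypothetical input shape for the Musique/2Wiki scanner.
import json
from kag.builder.component.scanner.dataset_scanner import MusiqueCorpusScanner  # module path assumed

corpus = [{"title": "Berlin", "text": "Berlin is the capital of Germany."}]
docs = MusiqueCorpusScanner().load_data(json.dumps(corpus))
# -> [{"id": "Berlin#0", "name": "Berlin", "content": "Berlin is the capital of Germany."}]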

View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import re
from typing import List
from kag.interface import ScannerABC
from knext.common.base.runnable import Input, Output
@ScannerABC.register("dir")
@ScannerABC.register("dir_file_scanner")
class DirectoryScanner(ScannerABC):
"""
A class for reading files from a directory based on a specified file pattern or suffix, inheriting from `ScannerABC`.
It can be used in conjunction with parsers such as the PDF or Markdown reader to convert the matched files into Chunks.
This class is responsible for reading files from a directory and returning a list of file paths that match the specified file pattern/suffix.
It inherits from `ScannerABC` and overrides the necessary methods to handle directory-specific operations.
"""
def __init__(
self,
file_pattern: str = None,
file_suffix: str = None,
rank: int = 0,
world_size: int = 1,
):
"""
Initializes the DirectoryScanner with the specified file pattern, file suffix, rank, and world size.
Args:
file_pattern (str, optional): The regex pattern to match file names. Defaults to None.
file_suffix (str, optional): The file suffix to match if `file_pattern` is not provided. Defaults to None.
rank (int, optional): The rank of the current worker. Defaults to 0.
world_size (int, optional): The total number of workers. Defaults to 1.
"""
super().__init__(rank=rank, world_size=world_size)
if file_pattern is None:
if file_suffix:
file_pattern = f".*{file_suffix}$"
else:
file_pattern = r".*txt$"
self.file_pattern = re.compile(file_pattern)
@property
def input_types(self) -> Input:
return str
@property
def output_types(self) -> Output:
return str
def find_files_by_regex(self, directory):
"""
Finds files in the specified directory that match the file pattern.
Args:
directory (str): The directory to search for files.
Returns:
List[str]: A list of file paths that match the file pattern.
"""
matched_files = []
for root, dirs, files in os.walk(directory):
for file in files:
if self.file_pattern.match(file):
file_path = os.path.join(root, file)
matched_files.append(file_path)
return matched_files
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data by finding files in the specified directory that match the file pattern.
This method searches the directory specified by the input and returns a list of file paths that match the file pattern.
Args:
input (Input): The directory to search for files.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of file paths that match the file pattern.
"""
return self.find_files_by_regex(input)
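A minimal sketch of pairing the directory scanner with a downstream reader; the directory is a placeholder and the module path is an assumption.
# Hypothetical usage: collect all Markdown files under ./docs for a downstream reader.
from kag.builder.component.scanner.directory_scanner import DirectoryScanner  # module path assumed

paths = DirectoryScanner(file_suffix="md").load_data("./docs")  # file_suffix expands to the regex ".*md$"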

View File

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List
from kag.interface import ScannerABC
from kag.common.conf import KAG_PROJECT_CONF
from knext.common.base.runnable import Input, Output
@ScannerABC.register("file")
@ScannerABC.register("file_scanner")
class FileScanner(ScannerABC):
"""
A class for reading single file and returning the path, inheriting from `ScannerABC`.
This class is responsible for reading a SINGLE file and returning its path as a list of strings.
It inherits from `ScannerABC` and overrides the necessary methods to handle file-specific operations.
"""
@property
def input_types(self) -> Input:
return str
@property
def output_types(self) -> Output:
return str
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data by returning the input file path as a list of strings.
This method takes the input file path and returns it as a list containing the file path.
Args:
input (Input): The file path to load.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list containing the input file path.
"""
if input.startswith("http://") or input.startswith("https://"):
from kag.common.utils import download_from_http
local_file_path = os.path.join(KAG_PROJECT_CONF.ckpt_dir, "file_scanner")
if not os.path.exists(local_file_path):
os.makedirs(local_file_path)
local_file = os.path.join(local_file_path, os.path.basename(input))
local_file = download_from_http(input, local_file)
return [local_file]
return [input]
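A hedged sketch of the two paths through load_data above; the file names and URL are placeholders and the module path is an assumption.
# Hypothetical usage: local paths pass through, http(s) URLs are downloaded first.
from kag.builder.component.scanner.file_scanner import FileScanner  # module path assumed

scanner = FileScanner()
local = scanner.load_data("./report.pdf")                     # -> ["./report.pdf"]
remote = scanner.load_data("https://example.com/report.pdf")  # -> ["<ckpt_dir>/file_scanner/report.pdf"]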

View File

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import os
from typing import Union, Dict, List
from kag.interface import ScannerABC
from knext.common.base.runnable import Input, Output
@ScannerABC.register("json")
@ScannerABC.register("json_scanner")
class JSONScanner(ScannerABC):
"""
A class for reading JSON files or parsing JSON-formatted strings into a list of dictionaries, inheriting from `ScannerABC`.
This class is responsible for reading JSON files or parsing JSON-formatted strings and converting them into a list of dictionaries.
It inherits from `ScannerABC` and overrides the necessary methods to handle JSON-specific operations.
Note: The JSON data must be a list of dictionaries.
"""
@property
def input_types(self) -> Input:
return str
@property
def output_types(self) -> Output:
return Dict
@staticmethod
def _read_from_file(file_path: str) -> Union[dict, list]:
"""
Reads JSON data from a file and returns it as a list of dictionaries.
Args:
file_path (str): The path to the JSON file.
Returns:
List[Dict]: The JSON data loaded from the file.
Raises:
ValueError: If there is an error reading the JSON from the file or if the file is not found.
"""
try:
with open(file_path, "r") as file:
return json.load(file)
except json.JSONDecodeError as e:
raise ValueError(f"Error reading JSON from file: {e}")
except FileNotFoundError as e:
raise ValueError(f"File not found: {e}")
@staticmethod
def _parse_json_string(json_string: str) -> Union[dict, list]:
"""
Parses a JSON string and returns it as a list of dictionaries.
Args:
json_string (str): The JSON string to parse.
Returns:
List[Dict]: The parsed JSON data.
Raises:
ValueError: If there is an error parsing the JSON string.
"""
try:
return json.loads(json_string)
except json.JSONDecodeError as e:
raise ValueError(f"Error parsing JSON string: {e}")
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data from a JSON file or JSON string and returns it as a list of dictionaries.
This method reads JSON data from a file or parses a JSON string and returns it as a list of dictionaries.
If the input is a file path, it reads the file; if the input is a JSON string, it parses the string.
Args:
input (Input): The JSON file path or JSON string to load.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of dictionaries, where each dictionary represents a JSON object.
Raises:
ValueError: If there is an error reading the JSON data or if the input is not a valid JSON array or object.
"""
input = self.download_data(input)
try:
if os.path.exists(input):
corpus = self._read_from_file(input)
else:
corpus = self._parse_json_string(input)
except ValueError as e:
raise e
if not isinstance(corpus, (list, dict)):
raise ValueError("Expected input to be a JSON array or object")
if isinstance(corpus, dict):
corpus = [corpus]
return corpus
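A small sketch of the accepted JSON shapes described above (module path assumed): an array becomes a list of dicts, and a single object is wrapped into a one-element list.
# Hypothetical usage with inline JSON strings.
from kag.builder.component.scanner.json_scanner import JSONScanner  # module path assumed

scanner = JSONScanner()
rows = scanner.load_data('[{"id": "1", "content": "hello"}, {"id": "2", "content": "world"}]')
one = scanner.load_data('{"id": "3", "content": "solo"}')  # -> [{"id": "3", "content": "solo"}]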

View File

@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import requests
from typing import Type, List, Union
# from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.interface import ScannerABC
from knext.common.base.runnable import Input, Output
@ScannerABC.register("yuque")
@ScannerABC.register("yuque_scanner")
class YuqueScanner(ScannerABC):
"""
A class for reading data from Yuque, a Chinese documentation platform, inheriting from `ScannerABC`.
This class is responsible for reading a Yuque knowledge base and returning the URLs of the documents it contains.
It can be used in conjunction with the Yuque parser to convert Yuque documents into Chunks.
It inherits from `ScannerABC` and overrides the necessary methods to handle Yuque-specific operations.
Args:
token (str): The authentication token for accessing Yuque API.
rank (int, optional): The rank of the current worker. Defaults to 0.
world_size (int, optional): The total number of workers. Defaults to 1.
"""
def __init__(self, token: str):
"""
Initializes the YuqueScanner with the specified authentication token.
Args:
token (str): The authentication token for accessing Yuque API.
"""
super().__init__()
self.token = token
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return Union[str, List[str]]
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return str
def get_yuque_api_data(self, url):
"""
Fetches data from the Yuque API using the specified URL and authentication token.
Args:
url (str): The URL to fetch data from.
Returns:
dict: The JSON data returned by the Yuque API.
Raises:
HTTPError: If the API returns a bad response (4xx or 5xx).
"""
headers = {"X-Auth-Token": self.token}
response = requests.get(url, headers=headers)
response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx)
return response.json()["data"] # Assuming the API returns JSON data
def load_data(self, input: Input, **kwargs) -> List[Output]:
"""
Loads data from the Yuque API and returns it as a list of document url strings.
This method fetches data from the Yuque API using the provided URL and converts it into a list of strings.
If the input is a single document url, it returns a list containing the token and URL.
If the input is a knowledge base, it returns a list of strings where each string contains the token and the URL of each document it contains.
Args:
input (Input): The URL to fetch data from.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of strings, where each string contains the token and the URL of each document.
"""
url = input
if isinstance(url, str):
data = self.get_yuque_api_data(url)
if isinstance(data, dict):
# for single yuque doc
return [f"{self.token}@{url}"]
output = []
for item in data:
slug = item["slug"]
output.append(os.path.join(url, slug))
return [f"{self.token}@{url}" for url in output]
else:
return [f"{self.token}@{x}" for x in url]

View File

@ -1,23 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.splitter.length_splitter import LengthSplitter
from kag.builder.component.splitter.semantic_splitter import SemanticSplitter
from kag.builder.component.splitter.pattern_splitter import PatternSplitter
from kag.builder.component.splitter.outline_splitter import OutlineSplitter
__all__ = [
"LengthSplitter",
"SemanticSplitter",
"PatternSplitter",
]

View File

@ -10,28 +10,52 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC
from typing import Type, List, Union
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SplitterABC
from kag.interface import SplitterABC
class BaseTableSplitter(SplitterABC):
"""
A base class for splitting table, inheriting from Splitter.
A base class for splitting table data into smaller chunks.
This class inherits from SplitterABC and provides the functionality to split table data
represented in markdown format into smaller chunks.
"""
def __init__(self):
super().__init__()
def split_table(self, org_chunk: Chunk, chunk_size: int = 2000, sep: str = "\n"):
"""
split markdown format table into smaller markdown table
Splits a markdown format table into smaller markdown tables.
Args:
org_chunk (Chunk): The original chunk containing the table data.
chunk_size (int): The maximum size of each smaller chunk. Defaults to 2000.
sep (str): The separator used to join the table rows. Defaults to "\n".
Returns:
List[Chunk]: A list of smaller chunks resulting from the split operation.
"""
try:
return self._split_table(org_chunk=org_chunk, chunk_size=chunk_size, sep=sep)
return self._split_table(
org_chunk=org_chunk, chunk_size=chunk_size, sep=sep
)
except Exception:
return None
def _split_table(self, org_chunk: Chunk, chunk_size: int = 2000, sep: str = "\n"):
"""
Internal method to split a markdown format table into smaller markdown tables.
Args:
org_chunk (Chunk): The original chunk containing the table data.
chunk_size (int): The maximum size of each smaller chunk. Defaults to 2000.
sep (str): The separator used to join the table rows. Defaults to "\n".
Returns:
List[Chunk]: A list of smaller chunks resulting from the split operation.
"""
output = []
content = org_chunk.content
table_start = content.find("|")
@ -56,6 +80,7 @@ class BaseTableSplitter(SplitterABC):
cur.append(row)
cur_len += len(row)
cur.append(content[table_end:])
if len(cur) > 0:
splitted.append(cur)
@ -66,7 +91,7 @@ class BaseTableSplitter(SplitterABC):
name=f"{org_chunk.name}#{idx}",
content=sep.join(sentences),
type=org_chunk.type,
**org_chunk.kwargs
**org_chunk.kwargs,
)
output.append(chunk)
return output
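A minimal sketch of split_table in action, exercised through LengthSplitter (which, per this diff, inherits from BaseTableSplitter); the Chunk constructor arguments follow the usage shown elsewhere in the diff, and the table content is synthetic.
# Hypothetical usage: a long markdown table chunk is split into sub-tables of roughly chunk_size characters.
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from kag.builder.component.splitter.length_splitter import LengthSplitter

rows = "\n".join(f"| row {i} | value {i} |" for i in range(200))
table_md = "| col_a | col_b |\n| --- | --- |\n" + rows
big_table = Chunk(id="t1", name="big_table", content=table_md, type=ChunkTypeEnum.Table)
parts = LengthSplitter().split_table(org_chunk=big_table, chunk_size=500)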

View File

@ -10,26 +10,41 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Type, List, Union
from typing import Type, List
from kag.interface import SplitterABC
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from kag.interface.builder.base import KAG_PROJECT_CONF
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Input, Output
from kag.builder.component.splitter.base_table_splitter import BaseTableSplitter
@SplitterABC.register("length")
@SplitterABC.register("length_splitter")
class LengthSplitter(BaseTableSplitter):
"""
A class for splitting text based on length, inheriting from Splitter.
A class for splitting text based on length.
This class inherits from BaseTableSplitter and provides the functionality to split text
into smaller chunks based on a specified length and window size. It also handles table data
by splitting it into smaller markdown tables.
Attributes:
split_length (int): The maximum length of each split chunk.
split_length (int): The maximum length of each chunk.
window_length (int): The length of the overlap between chunks.
"""
def __init__(self, split_length: int = 500, window_length: int = 100, **kwargs):
super().__init__(**kwargs)
self.split_length = int(split_length)
self.window_length = int(window_length)
def __init__(self, split_length: int = 500, window_length: int = 100):
"""
Initializes the LengthSplitter with the specified split length and window length.
Args:
split_length (int): The maximum length of each chunk. Defaults to 500.
window_length (int): The length of the overlap between chunks. Defaults to 100.
"""
super().__init__()
self.split_length = split_length
self.window_length = window_length
@property
def input_types(self) -> Type[Input]:
@ -39,37 +54,52 @@ class LengthSplitter(BaseTableSplitter):
def output_types(self) -> Type[Output]:
return Chunk
def chunk_breakdown(self, chunk):
chunks = self.logic_break(chunk)
if chunks:
res_chunks = []
for c in chunks:
res_chunks.extend(self.chunk_breakdown(c))
else:
res_chunks = self.slide_window_chunk(
chunk, self.split_length, self.window_length
)
return res_chunks
def logic_break(self, chunk):
return None
def split_sentence(self, content):
"""
Splits the given content into sentences based on delimiters.
Args:
content (str): The content to be split.
content (str): The content to be split into sentences.
Returns:
list: A list of sentences.
List[str]: A list of sentences.
"""
sentence_delimiters = ".。??!"
sentence_delimiters = ".。??!" if KAG_PROJECT_CONF.language == "en" else "。?!"
output = []
start = 0
for idx, char in enumerate(content):
if char in sentence_delimiters:
end = idx
tmp = content[start: end + 1].strip()
tmp = content[start : end + 1].strip()
if len(tmp) > 0:
output.append(tmp)
output.append(tmp.strip())
start = idx + 1
res = content[start:]
res = content[start:].strip()
if len(res) > 0:
output.append(res)
return output
def slide_window_chunk(
self,
org_chunk: Chunk,
chunk_size: int = 2000,
window_length: int = 300,
sep: str = "\n",
self,
org_chunk: Chunk,
chunk_size: int = 2000,
window_length: int = 300,
sep: str = "\n",
) -> List[Chunk]:
"""
Splits the content into chunks using a sliding window approach.
@ -84,7 +114,9 @@ class LengthSplitter(BaseTableSplitter):
List[Chunk]: A list of Chunk objects.
"""
if org_chunk.type == ChunkTypeEnum.Table:
table_chunks = self.split_table(org_chunk=org_chunk, chunk_size=chunk_size, sep=sep)
table_chunks = self.split_table(
org_chunk=org_chunk, chunk_size=chunk_size, sep=sep
)
if table_chunks is not None:
return table_chunks
content = self.split_sentence(org_chunk.content)
@ -112,38 +144,36 @@ class LengthSplitter(BaseTableSplitter):
output = []
for idx, sentences in enumerate(splitted):
chunk = Chunk(
id=f"{org_chunk.id}#{chunk_size}#{window_length}#{idx}#LEN",
id=generate_hash_id(f"{org_chunk.id}#{idx}"),
name=f"{org_chunk.name}",
content=sep.join(sentences),
type=org_chunk.type,
**org_chunk.kwargs
chunk_size=chunk_size,
window_length=window_length,
**org_chunk.kwargs,
)
output.append(chunk)
return output
def invoke(self, input: Chunk, **kwargs) -> List[Output]:
def _invoke(self, input: Chunk, **kwargs) -> List[Output]:
"""
Invokes the splitter on the given input chunk.
Invokes the splitting of the input chunk based on the specified length and window size.
Args:
input (Chunk): The input chunk to be split.
**kwargs: Additional keyword arguments.
input (Chunk): The chunk(s) to be split.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list of split chunks.
List[Output]: A list of Chunk objects resulting from the split operation.
"""
cutted = []
if isinstance(input,list):
if isinstance(input, list):
for item in input:
cutted.extend(
self.slide_window_chunk(
item, self.split_length, self.window_length
)
self.slide_window_chunk(item, self.split_length, self.window_length)
)
else:
cutted.extend(
self.slide_window_chunk(
input, self.split_length, self.window_length
)
self.slide_window_chunk(input, self.split_length, self.window_length)
)
return cutted
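To illustrate the sliding-window split above, a hedged sketch with synthetic text; the numbers are illustrative, the Chunk is assumed to default to a text-type chunk, and invoke() is assumed to wrap _invoke().
# Hypothetical usage: split a long text chunk into ~200-character windows with a 50-character overlap.
from kag.builder.model.chunk import Chunk
from kag.builder.component.splitter.length_splitter import LengthSplitter

long_text = "KAG builds a graph index from source documents. " * 40
doc = Chunk(id="doc-1", name="doc-1", content=long_text)
pieces = LengthSplitter(split_length=200, window_length=50).invoke(doc)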

File diff suppressed because it is too large Load Diff

View File

@ -10,27 +10,37 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Type, List, Union
# flake8: noqa
import re
import os
from typing import Type, List, Union
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from kag.interface.builder.splitter_abc import SplitterABC
from kag.builder.model.chunk import Chunk
from kag.interface import SplitterABC
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Input, Output
@SplitterABC.register("pattern")
@SplitterABC.register("pattern_splitter")
class PatternSplitter(SplitterABC):
def __init__(self, pattern_dict: dict = None, chunk_cut_num=None):
"""
A class for splitting text content based on specified patterns and chunking strategies.
"""
def __init__(self, pattern_dict: dict = None, chunk_cut_num: int = None):
"""
pattern_dict:
{
"pattern": 匹配pattern,
"group": {
"header":1,
"name":2,
"content":3
}
}
Initializes the PatternSplitter with the given pattern dictionary and chunk cut number.
Args:
pattern_dict (dict, optional): A dictionary containing the pattern and group mappings.
Defaults to a predefined pattern if not provided.
Example:
{
"pattern": r"(\d+).([^0-9]+?)([^0-9第版].*?)(?=\d+\.|$)",
"group": {"header": 2, "name": 2, "content": 0}
}
chunk_cut_num (int, optional): The number of characters to cut chunks into. Defaults to None.
"""
super().__init__()
if pattern_dict is None:
@ -53,6 +63,15 @@ class PatternSplitter(SplitterABC):
return List[Chunk]
def split_sentence(self, content):
"""
Splits the given content into sentences based on delimiters.
Args:
content (str): The content to be split into sentences.
Returns:
List[str]: A list of sentences extracted from the content.
"""
sentence_delimiters = "。??!;\n"
output = []
start = 0
@ -76,7 +95,19 @@ class PatternSplitter(SplitterABC):
sep: str = "\n",
prefix: str = "SlideWindow",
) -> List[Chunk]:
"""
Splits the content into chunks using a sliding window approach.
Args:
content (Union[str, List[str]]): The content to be chunked.
chunk_size (int, optional): The maximum size of each chunk. Defaults to 2000.
window_length (int, optional): The length of the sliding window. Defaults to 300.
sep (str, optional): The separator to join sentences within a chunk. Defaults to "\n".
prefix (str, optional): The prefix to use for chunk names. Defaults to "SlideWindow".
Returns:
List[Chunk]: A list of Chunk objects representing the chunked content.
"""
if isinstance(content, str):
content = self.split_sentence(content)
splitted = []
@ -103,7 +134,7 @@ class PatternSplitter(SplitterABC):
for idx, sentences in enumerate(splitted):
chunk_name = f"{prefix}#{idx}"
chunk = Chunk(
id=Chunk.generate_hash_id(chunk_name),
id=generate_hash_id(chunk_name),
name=chunk_name,
content=sep.join(sentences),
)
@ -114,6 +145,15 @@ class PatternSplitter(SplitterABC):
self,
chunk: Chunk,
) -> List[Chunk]:
"""
Splits the given chunk into smaller chunks based on the pattern and chunk cut number.
Args:
chunk (Chunk): The chunk to be split.
Returns:
List[Chunk]: A list of smaller Chunk objects.
"""
text = chunk.content
pattern = re.compile(self.pattern, re.DOTALL)
@ -127,7 +167,7 @@ class PatternSplitter(SplitterABC):
chunk = Chunk(
chunk_header=match.group(self.group["header"]),
name=match.group(self.group["name"]),
id=Chunk.generate_hash_id(match.group(self.group["content"])),
id=generate_hash_id(match.group(self.group["content"])),
content=match.group(self.group["content"]),
)
chunk = [chunk]
@ -145,43 +185,16 @@ class PatternSplitter(SplitterABC):
return chunks
def invoke(self, input: Chunk, **kwargs) -> List[Output]:
def _invoke(self, input: Chunk, **kwargs) -> List[Output]:
"""
Invokes the chunk splitting process on the given input.
Args:
input (Chunk): The input chunk to be processed.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list of output chunks.
"""
chunks = self.chunk_split(input)
return chunks
def to_rest(self):
pass
@classmethod
def from_rest(cls, rest_model):
pass
class LayeredPatternSpliter(PatternSplitter):
pass
def _test():
pattern_dict = {
"pattern": r"(\d+)\.([^0-9]+?)([^0-9第版].*?)(?=\d+\.|$)",
"group": {"header": 2, "name": 2, "content": 0},
}
ds = PatternSplitter(pattern_dict=pattern_dict)
from kag.builder.component.reader.pdf_reader import PDFReader
reader = PDFReader()
file_path = os.path.dirname(__file__)
test_file_path = os.path.join(file_path, "../../../../tests/builder/data/aiwen.pdf")
pre_output = reader._handle(test_file_path)
handle_input = pre_output[0]
handle_result = ds._handle(handle_input)
print("handle_result", handle_result)
return handle_result
if __name__ == "__main__":
res = _test()
print(res)
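The pattern_dict contract documented above (a regex plus a mapping of header/name/content to capture groups) can be illustrated with a standalone sketch; the sample pattern and text below are illustrative, not taken from the repo:

import re

pattern_dict = {
    "pattern": r"(Section \d+)\.([^S]*)",
    "group": {"header": 1, "name": 1, "content": 2},
}
text = "Section 1. Parties deliver on time. Section 2. Payment is due within 30 days."
pattern = re.compile(pattern_dict["pattern"], re.DOTALL)
for match in pattern.finditer(text):
    name = match.group(pattern_dict["group"]["name"])
    content = match.group(pattern_dict["group"]["content"]).strip()
    print(f"{name!r} -> {content!r}")
# 'Section 1' -> 'Parties deliver on time.'
# 'Section 2' -> 'Payment is due within 30 days.'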

View File

@ -10,41 +10,56 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import re
from typing import List, Type
from kag.interface.builder import SplitterABC
from kag.interface import SplitterABC
from kag.builder.prompt.semantic_seg_prompt import SemanticSegPrompt
from kag.builder.model.chunk import Chunk
from kag.interface import LLMClient
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import generate_hash_id
from knext.common.base.runnable import Input, Output
from kag.common.llm.client.llm_client import LLMClient
logger = logging.getLogger(__name__)
@SplitterABC.register("semantic")
@SplitterABC.register("semantic_splitter")
class SemanticSplitter(SplitterABC):
"""
A class for semantically splitting text into smaller chunks based on the content's structure and meaning.
Inherits from the Splitter class.
Inherits from the SplitterABC class.
Attributes:
kept_char_pattern (re.Pattern): Regex pattern to match Chinese/ASCII characters.
split_length (int): The maximum length of each chunk after splitting.
llm_client (LLMClient): Instance of LLMClient initialized with `model` config.
semantic_seg_op (SemanticSegPrompt): Instance of SemanticSegPrompt for semantic segmentation.
"""
def __init__(self, split_length: int = 1000, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
llm: LLMClient,
kept_char_pattern: str = None,
split_length: int = 1000,
):
"""
Initializes the SemanticSplitter with the given LLMClient, kept character pattern, and split length.
Args:
llm (LLMClient): Instance of LLMClient initialized with `model` config.
kept_char_pattern (str, optional): Regex pattern to match Chinese/ASCII characters.
Defaults to a predefined pattern if not provided.
split_length (int, optional): The maximum length of each chunk after splitting. Defaults to 1000.
**kwargs: Additional keyword arguments to be passed to the superclass.
"""
super().__init__()
# Chinese/ASCII characters
self.kept_char_pattern = re.compile(
r"[^\u4e00-\u9fa5\u3000-\u303F\uFF01-\uFF0F\uFF1A-\uFF20\uFF3B-\uFF40\uFF5B-\uFF65\x00-\x7F]+"
)
self.split_length = int(split_length)
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.semantic_seg_op = SemanticSegPrompt(language)
if kept_char_pattern is None:
self.kept_char_pattern = re.compile(
r"[^\u4e00-\u9fa5\u3000-\u303F\uFF01-\uFF0F\uFF1A-\uFF20\uFF3B-\uFF40\uFF5B-\uFF65\x00-\x7F]+"
)
else:
self.kept_char_pattern = re.compile(kept_char_pattern)
self.split_length = split_length
self.llm = llm
self.semantic_seg_op = SemanticSegPrompt(KAG_PROJECT_CONF.language)
@property
def input_types(self) -> Type[Input]:
@ -103,6 +118,8 @@ class SemanticSplitter(SplitterABC):
"""
result = self.llm.invoke({"input": org_chunk.content}, self.semantic_seg_op)
splitted = self.parse_llm_output(org_chunk.content, result)
if len(splitted) == 0:
return [org_chunk]
logger.debug(f"splitted = {splitted}")
chunks = []
for idx, item in enumerate(splitted):
@ -113,30 +130,26 @@ class SemanticSplitter(SplitterABC):
name=f"{org_chunk.name}#{split_name}",
content=item["content"],
abstract=item["name"],
**org_chunk.kwargs
**org_chunk.kwargs,
)
chunks.append(chunk)
else:
print("chunk over size")
innerChunk = Chunk(
id=Chunk.generate_hash_id(item["content"]),
id=generate_hash_id(item["content"]),
name=f"{org_chunk.name}#{split_name}",
content=item["content"],
)
chunks.extend(
self.semantic_chunk(
innerChunk, chunk_size
)
)
chunks.extend(self.semantic_chunk(innerChunk, chunk_size))
return chunks
def invoke(self, input: Input, **kwargs) -> List[Output]:
def _invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the splitting process on the provided input.
Args:
input (Input): The input to be processed.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list of outputs generated from the input.
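A hedged wiring sketch for the refactored constructor shown above; only the signature (llm, kept_char_pattern, split_length) comes from this diff, while the module path of SemanticSplitter and how the LLMClient instance is obtained are assumptions:

from kag.interface import LLMClient  # base class, import path as used in this diff

def build_semantic_splitter(llm: LLMClient, split_length: int = 1000):
    # Module path assumed; after this refactor the LLM client is injected
    # instead of being constructed internally from environment variables.
    from kag.builder.component.splitter.semantic_splitter import SemanticSplitter
    return SemanticSplitter(llm=llm, split_length=split_length)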

View File

@ -1,11 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -9,17 +9,18 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from collections import defaultdict
from typing import List
from tenacity import stop_after_attempt, retry
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from kag.common.vectorizer import Vectorizer
from kag.interface.builder.vectorizer_abc import VectorizerABC
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import get_vector_field_name
from kag.interface import VectorizerABC, VectorizeModelABC
from knext.schema.client import SchemaClient
from knext.project.client import ProjectClient
from knext.schema.model.base import IndexTypeEnum
from knext.common.base.runnable import Input, Output
class EmbeddingVectorPlaceholder(object):
@ -43,22 +44,15 @@ class EmbeddingVectorManager(object):
def __init__(self):
self._placeholders = []
def _create_vector_field_name(self, property_key):
from kag.common.utils import to_snake_case
name = f"{property_key}_vector"
name = to_snake_case(name)
return "_" + name
def get_placeholder(self, properties, vector_field):
for property_key, property_value in properties.items():
field_name = self._create_vector_field_name(property_key)
field_name = get_vector_field_name(property_key)
if field_name != vector_field:
continue
if not property_value:
return None
if not isinstance(property_value, str):
message = f"property {property_key!r} must be string to generate embedding vector"
message = f"property {property_key!r} must be string to generate embedding vector, got {property_value} with type {type(property_value)}"
raise RuntimeError(message)
num = len(self._placeholders)
placeholder = EmbeddingVectorPlaceholder(
@ -78,11 +72,10 @@ class EmbeddingVectorManager(object):
return text_batch
def _generate_vectors(self, vectorizer, text_batch, batch_size=32):
if isinstance(text_batch, str):
text_batch = [text_batch]
texts = list(text_batch)
if not texts:
return []
if len(texts) % batch_size == 0:
n_batchs = len(texts) // batch_size
else:
@ -99,9 +92,9 @@ class EmbeddingVectorManager(object):
for placeholder in placeholders:
placeholder._embedding_vector = vector
def batch_generate(self, vectorizer):
def batch_generate(self, vectorizer, batch_size=32):
text_batch = self._get_text_batch()
vectors = self._generate_vectors(vectorizer, text_batch)
vectors = self._generate_vectors(vectorizer, text_batch, batch_size)
self._fill_vectors(vectors, text_batch)
def patch(self):
@ -115,7 +108,7 @@ class EmbeddingVectorGenerator(object):
self._extra_labels = extra_labels
self._vector_index_meta = vector_index_meta or {}
def batch_generate(self, node_batch):
def batch_generate(self, node_batch, batch_size=32):
manager = EmbeddingVectorManager()
vector_index_meta = self._vector_index_meta
for node_item in node_batch:
@ -132,41 +125,49 @@ class EmbeddingVectorGenerator(object):
placeholder = manager.get_placeholder(properties, vector_field)
if placeholder is not None:
properties[vector_field] = placeholder
manager.batch_generate(self._vectorizer)
manager.batch_generate(self._vectorizer, batch_size)
manager.patch()
@VectorizerABC.register("batch")
@VectorizerABC.register("batch_vectorizer")
class BatchVectorizer(VectorizerABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.project_id = self.project_id or os.getenv("KAG_PROJECT_ID")
self._init_graph_store()
self.vec_meta = self._init_vec_meta()
self.vectorizer = Vectorizer.from_config(self.vectorizer_config)
"""
A class for generating embedding vectors for node attributes in a SubGraph in batches.
def _init_graph_store(self):
This class inherits from VectorizerABC and provides the functionality to generate embedding vectors
for node attributes in a SubGraph in batches. It uses a specified vectorization model and processes the nodes in batches of a specified size.
Attributes:
project_id (int): The ID of the project associated with the SubGraph.
vec_meta (defaultdict): Metadata for vector fields in the SubGraph.
vectorize_model (VectorizeModelABC): The model used for generating embedding vectors.
batch_size (int): The size of the batches in which to process the nodes.
"""
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
"""
Initializes the Graph Store client.
This method retrieves the graph store configuration from environment variables and the project ID.
It then fetches the project configuration using the project ID and updates the graph store configuration
with any additional settings from the project. Finally, it creates and initializes the graph store client
using the updated configuration.
Initializes the BatchVectorizer with the specified vectorization model and batch size.
Args:
project_id (str): The id of project.
Returns:
GraphStore
vectorize_model (VectorizeModelABC): The model used for generating embedding vectors.
batch_size (int): The size of the batches in which to process the nodes. Defaults to 32.
"""
graph_store_config = eval(os.getenv("KAG_GRAPH_STORE", "{}"))
vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
config = ProjectClient().get_config(self.project_id)
graph_store_config.update(config.get("graph_store", {}))
vectorizer_config.update(config.get("vectorizer", {}))
self.vectorizer_config = vectorizer_config
super().__init__()
self.project_id = KAG_PROJECT_CONF.project_id
# self._init_graph_store()
self.vec_meta = self._init_vec_meta()
self.vectorize_model = vectorize_model
self.batch_size = batch_size
def _init_vec_meta(self):
"""
Initializes the vector metadata for the SubGraph.
Returns:
defaultdict: Metadata for vector fields in the SubGraph.
"""
vec_meta = defaultdict(list)
schema_client = SchemaClient(project_id=self.project_id)
spg_types = schema_client.load()
@ -176,32 +177,31 @@ class BatchVectorizer(VectorizerABC):
IndexTypeEnum.Vector,
IndexTypeEnum.TextAndVector,
]:
vec_meta[type_name].append(
self._create_vector_field_name(prop_name)
)
vec_meta[type_name].append(get_vector_field_name(prop_name))
return vec_meta
def _create_vector_field_name(self, property_key):
from kag.common.utils import to_snake_case
@retry(stop=stop_after_attempt(3))
def _generate_embedding_vectors(self, input_subgraph: SubGraph) -> SubGraph:
"""
Generates embedding vectors for the nodes in the input SubGraph.
name = f"{property_key}_vector"
name = to_snake_case(name)
return "_" + name
Args:
input_subgraph (SubGraph): The SubGraph for which to generate embedding vectors.
def _generate_embedding_vectors(
self, vectorizer: Vectorizer, input: SubGraph
) -> SubGraph:
Returns:
SubGraph: The modified SubGraph with generated embedding vectors.
"""
node_list = []
node_batch = []
for node in input.nodes:
for node in input_subgraph.nodes:
if not node.id or not node.name:
continue
properties = {"id": node.id, "name": node.name}
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
generator = EmbeddingVectorGenerator(vectorizer, self.vec_meta)
generator.batch_generate(node_batch)
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
generator.batch_generate(node_batch, self.batch_size)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch
):
@ -209,8 +209,18 @@ class BatchVectorizer(VectorizerABC):
if key in new_properties and new_properties[key] == value:
del new_properties[key]
node.properties.update(new_properties)
return input
return input_subgraph
def invoke(self, input: Input, **kwargs) -> List[Output]:
modified_input = self._generate_embedding_vectors(self.vectorizer, input)
def _invoke(self, input_subgraph: Input, **kwargs) -> List[Output]:
"""
Invokes the generation of embedding vectors for the input SubGraph.
Args:
input_subgraph (Input): The SubGraph for which to generate embedding vectors.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list containing the modified SubGraph with generated embedding vectors.
"""
modified_input = self._generate_embedding_vectors(input_subgraph)
return [modified_input]
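The batch_size plumbing added above reduces to ceil-division batching over the collected texts; a standalone sketch with a dummy embedder standing in for the vectorize model:

def embed_in_batches(texts, embed_fn, batch_size=32):
    texts = list(texts)
    if not texts:
        return []
    n_batches = (len(texts) + batch_size - 1) // batch_size  # ceil division
    vectors = []
    for i in range(n_batches):
        vectors.extend(embed_fn(texts[i * batch_size : (i + 1) * batch_size]))
    return vectors

# Dummy embedder: one fake 2-dim vector per input text.
fake_embed = lambda batch: [[float(len(t)), 0.0] for t in batch]
print(len(embed_in_batches([f"text {i}" for i in range(70)], fake_embed, batch_size=32)))  # 70 (3 batches)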

View File

@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.writer.kg_writer import KGWriter
__all__ = [
"KGWriter",
]

View File

@ -9,14 +9,15 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import os
from enum import Enum
from typing import Type, Dict, List
from knext.graph_algo.client import GraphAlgoClient
from knext.graph.client import GraphClient
from kag.builder.model.sub_graph import SubGraph
from kag.interface.builder.writer_abc import SinkWriterABC
from kag.interface import SinkWriterABC
from kag.common.conf import KAG_PROJECT_CONF
from knext.common.base.runnable import Input, Output
logger = logging.getLogger(__name__)
@ -27,19 +28,30 @@ class AlterOperationEnum(str, Enum):
Delete = "DELETE"
@SinkWriterABC.register("kg", as_default=True)
@SinkWriterABC.register("kg_writer", as_default=True)
class KGWriter(SinkWriterABC):
"""
A class that extends `SinkWriter` to handle writing data into a Neo4j knowledge graph.
A class for writing SubGraphs to a Knowledge Graph (KG) storage.
This class is responsible for configuring the graph store based on environment variables and
an optional project ID, initializing the Neo4j client, and setting up the schema.
It also manages semantic indexing and multi-threaded operations.
This class inherits from SinkWriterABC and provides the functionality to write SubGraphs
to a Knowledge Graph storage system. It supports operations like upsert and delete.
"""
def __init__(self, project_id: str = None, **kwargs):
def __init__(self, project_id: int = None, **kwargs):
"""
Initializes the KGWriter with the specified project ID.
Args:
project_id (int): The ID of the project associated with the KG. Defaults to None.
**kwargs: Additional keyword arguments passed to the superclass.
"""
super().__init__(**kwargs)
self.project_id = project_id or os.getenv("KAG_PROJECT_ID")
self.client = GraphAlgoClient(project_id=project_id)
if project_id is None:
self.project_id = KAG_PROJECT_CONF.project_id
else:
self.project_id = project_id
self.client = GraphClient(project_id=project_id)
@property
def input_types(self) -> Type[Input]:
@ -49,25 +61,84 @@ class KGWriter(SinkWriterABC):
def output_types(self) -> Type[Output]:
return None
def format_label(self, label: str):
"""
Formats the label by adding the project namespace if it is not already present.
Args:
label (str): The label to be formatted.
Returns:
str: The formatted label.
"""
namespace = KAG_PROJECT_CONF.namespace
if label.split(".")[0] == namespace:
return label
return f"{namespace}.{label}"
def standarlize_graph(self, graph):
for node in graph.nodes:
node.label = self.format_label(node.label)
for edge in graph.edges:
edge.from_type = self.format_label(edge.from_type)
edge.to_type = self.format_label(edge.to_type)
for node in graph.nodes:
for k, v in node.properties.items():
if k.startswith("_"):
continue
if not isinstance(v, str):
node.properties[k] = json.dumps(v, ensure_ascii=False)
for edge in graph.edges:
for k, v in edge.properties.items():
if k.startswith("_"):
continue
if not isinstance(v, str):
edge.properties[k] = json.dumps(v, ensure_ascii=False)
return graph
def invoke(
self, input: Input, alter_operation: str = AlterOperationEnum.Upsert, lead_to_builder: bool = False
self,
input: Input,
alter_operation: str = AlterOperationEnum.Upsert,
lead_to_builder: bool = False,
**kwargs,
) -> List[Output]:
"""
Invokes the specified operation (upsert or delete) on the graph store.
Args:
input (Input): The input object representing the subgraph to operate on.
alter_operation (str): The type of operation to perform (Upsert or Delete).
lead_to_builder (str): enable lead to event infer builder
alter_operation (str): The type of operation to perform (Upsert or Delete). Defaults to Upsert.
lead_to_builder (bool): Enable the lead-to event inference builder. Defaults to False.
Returns:
List[Output]: A list of output objects (currently always [None]).
"""
self.client.write_graph(sub_graph=input.to_dict(), operation=alter_operation, lead_to_builder=lead_to_builder)
return [None]
input = self.standarlize_graph(input)
logger.debug(f"final graph to write: {input}")
self.client.write_graph(
sub_graph=input.to_dict(),
operation=alter_operation,
lead_to_builder=lead_to_builder,
)
return [input]
def _handle(self, input: Dict, alter_operation: str, **kwargs):
"""The calling interface provided for SPGServer."""
"""
The calling interface provided for SPGServer.
Args:
input (Dict): The input dictionary representing the subgraph to operate on.
alter_operation (str): The type of operation to perform (Upsert or Delete).
**kwargs: Additional keyword arguments.
Returns:
None: This method currently returns None.
"""
_input = self.input_types.from_dict(input)
_output = self.invoke(_input, alter_operation)
_output = self.invoke(_input, alter_operation) # noqa
return None
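A standalone sketch of the two normalization steps introduced in standarlize_graph: namespace prefixing of labels and JSON-encoding of non-string properties (the namespace value here is illustrative):

import json

def format_label(label, namespace="MyKG"):
    # Prepend the project namespace unless the label already carries it.
    return label if label.split(".")[0] == namespace else f"{namespace}.{label}"

def normalize_properties(properties):
    # Internal properties (leading underscore) and strings pass through unchanged;
    # everything else is serialized to a JSON string before writing.
    return {
        k: v if k.startswith("_") or isinstance(v, str) else json.dumps(v, ensure_ascii=False)
        for k, v in properties.items()
    }

print(format_label("Person"))       # MyKG.Person
print(format_label("MyKG.Person"))  # MyKG.Person
print(normalize_properties({"age": 30, "name": "Alice", "_name_vector": [0.1, 0.2]}))
# {'age': '30', 'name': 'Alice', '_name_vector': [0.1, 0.2]}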

View File

@ -9,149 +9,182 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import importlib
import os
from kag.builder.component import SPGTypeMapping, KGWriter
from kag.builder.component.extractor import KAGExtractor
from kag.builder.component.splitter import LengthSplitter
from kag.builder.component.vectorizer.batch_vectorizer import BatchVectorizer
from knext.common.base.chain import Chain
from knext.builder.builder_chain_abc import BuilderChainABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from kag.interface import (
ReaderABC,
MappingABC,
ExtractorABC,
SplitterABC,
VectorizerABC,
PostProcessorABC,
SinkWriterABC,
KAGBuilderChain,
)
from kag.common.utils import generate_hash_id
logger = logging.getLogger(__name__)
def get_reader(file_path: str):
file = os.path.basename(file_path)
suffix = file.split(".")[-1]
assert suffix.lower() in READER_MAPPING, f"{suffix} is not supported. Supported suffixes are: {list(READER_MAPPING.keys())}"
reader_path = READER_MAPPING.get(suffix.lower())
mod_path, class_name = reader_path.rsplit('.', 1)
module = importlib.import_module(mod_path)
reader_class = getattr(module, class_name)
return reader_class
READER_MAPPING = {
"csv": "kag.builder.component.reader.csv_reader.CSVReader",
"json": "kag.builder.component.reader.json_reader.JSONReader",
"txt": "kag.builder.component.reader.txt_reader.TXTReader",
"pdf": "kag.builder.component.reader.pdf_reader.PDFReader",
"docx": "kag.builder.component.reader.docx_reader.DocxReader",
"md": "kag.builder.component.reader.markdown_reader.MarkdownReader",
}
class DefaultStructuredBuilderChain(BuilderChainABC):
@KAGBuilderChain.register("structured")
@KAGBuilderChain.register("structured_builder_chain")
class DefaultStructuredBuilderChain(KAGBuilderChain):
"""
A class representing a default SPG builder chain, used to import structured data based on schema definitions
Steps:
0. Initializing by a given SpgType name, which indicates the target of import.
1. SourceReader: Reading structured dicts from a given file.
2. SPGTypeMapping: Mapping source fields to the properties of target type, and assemble a sub graph.
By default, the same name mapping is used, which means importing the source field into a property with the same name.
3. KGWriter: Writing sub graph into KG storage.
Attributes:
spg_type_name (str): The name of the SPG type.
A class representing a default SPG builder chain, used to import structured data based on schema definitions.
It consists of a mapping component, a writer component, and an optional vectorizer component.
"""
def __init__(self, spg_type_name: str, **kwargs):
super().__init__(**kwargs)
self.spg_type_name = spg_type_name
def __init__(
self,
mapping: MappingABC,
writer: SinkWriterABC,
vectorizer: VectorizerABC = None,
):
"""
Initializes the DefaultStructuredBuilderChain instance.
Args:
mapping (MappingABC): The mapping component to be used.
writer (SinkWriterABC): The writer component to be used.
vectorizer (VectorizerABC, optional): The vectorizer component to be used. Defaults to None.
"""
self.mapping = mapping
self.writer = writer
self.vectorizer = vectorizer
def build(self, **kwargs):
"""
Builds the processing chain for the SPG.
Construct the builder chain by connecting the mapping, vectorizer (if available), and writer components.
Args:
**kwargs: Additional keyword arguments.
Returns:
chain: The constructed processing chain.
KAGBuilderChain: The constructed builder chain.
"""
file_path = kwargs.get("file_path")
source = get_reader(file_path)(output_type="Dict")
mapping = SPGTypeMapping(spg_type_name=self.spg_type_name)
sink = KGWriter()
if self.vectorizer:
chain = self.mapping >> self.vectorizer >> self.writer
else:
chain = self.mapping >> self.writer
chain = source >> mapping >> sink
return chain
def invoke(self, file_path, max_workers=10, **kwargs):
logger.info(f"begin processing file_path:{file_path}")
"""
Invokes the processing chain with the given file path and optional parameters.
# def get_component_with_ckpts(self):
# return [
# self.mapping,
# self.vectorizer,
# self.writer,
# ]
Args:
file_path (str): The path to the input file.
max_workers (int, optional): The maximum number of workers. Defaults to 10.
**kwargs: Additional keyword arguments.
Returns:
The result of invoking the processing chain.
"""
return super().invoke(file_path=file_path, max_workers=max_workers, **kwargs)
# def close_checkpointers(self):
# for node in self.get_component_with_ckpts():
# if node and hasattr(node, "checkpointer"):
# node.checkpointer.close()
class DefaultUnstructuredBuilderChain(BuilderChainABC):
@KAGBuilderChain.register("unstructured")
@KAGBuilderChain.register("unstructured_builder_chain")
class DefaultUnstructuredBuilderChain(KAGBuilderChain):
"""
A class representing a default KAG builder chain, used to extract graph from documents and import unstructured data.
Steps:
0. Initializing.
1. SourceReader: Reading chunks from a given file.
2. LengthSplitter: Splitting chunk to smaller chunks. The chunk size can be adjusted through parameters.
3. KAGExtractor: Extracting entities and relations from chunks, and assembling a sub graph.
By default, the extraction process includes NER and SPO extraction.
4. KGWriter: Writing sub graph into KG storage.
A class representing a default unstructured builder chain, used to build a knowledge graph from unstructured text data such as txt and pdf files.
It consists of a reader, splitter, extractor, vectorizer, optional post-processor, and writer components.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, **kwargs) -> Chain:
def __init__(
self,
reader: ReaderABC,
splitter: SplitterABC,
extractor: ExtractorABC = None,
vectorizer: VectorizerABC = None,
writer: SinkWriterABC = None,
post_processor: PostProcessorABC = None,
):
"""
Builds the processing chain for the KAG.
Initializes the DefaultUnstructuredBuilderChain instance.
Args:
reader (ReaderABC): The reader component to be used.
splitter (SplitterABC): The splitter component to be used.
extractor (ExtractorABC): The extractor component to be used.
vectorizer (VectorizerABC): The vectorizer component to be used.
writer (SinkWriterABC): The writer component to be used.
post_processor (PostProcessorABC, optional): The post-processor component to be used. Defaults to None.
"""
self.reader = reader
self.splitter = splitter
self.extractor = extractor
self.vectorizer = vectorizer
self.post_processor = post_processor
self.writer = writer
def build(self, **kwargs):
pass
def invoke(self, input_data, max_workers=10, **kwargs):
"""
Invokes the builder chain to process the input file.
Args:
file_path: The path to the input file to be processed.
max_workers (int, optional): The maximum number of threads to use. Defaults to 10.
**kwargs: Additional keyword arguments.
Returns:
chain: The constructed processing chain.
List: The final output from the builder chain.
"""
file_path = kwargs.get("file_path")
split_length = kwargs.get("split_length")
window_length = kwargs.get("window_length")
source = get_reader(file_path)()
splitter = LengthSplitter(split_length, window_length)
extractor = KAGExtractor()
vectorizer = BatchVectorizer()
sink = KGWriter()
chain = source >> splitter >> extractor >> vectorizer >> sink
return chain
def execute_node(node, node_input, **kwargs):
if not isinstance(node_input, list):
node_input = [node_input]
node_output = []
for item in node_input:
node_output.extend(node.invoke(item, **kwargs))
return node_output
def invoke(self, file_path: str, split_length: int = 500, window_length: int = 100, max_workers=10, **kwargs):
logger.info(f"begin processing file_path:{file_path}")
"""
Invokes the processing chain with the given file path and optional parameters.
def run_extract(chunk):
flow_data = [chunk]
input_key = chunk.hash_key
for node in [
self.extractor,
self.vectorizer,
self.post_processor,
self.writer,
]:
if node is None:
continue
flow_data = execute_node(node, flow_data, key=input_key)
return {input_key: flow_data[0]}
Args:
file_path (str): The path to the input file.
split_length (int, optional): The length at which the file should be split. Defaults to 500.
window_length (int, optional): The length of the processing window. Defaults to 100.
max_workers (int, optional): The maximum number of worker threads. Defaults to 10.
reader_output = self.reader.invoke(input_data, key=generate_hash_id(input_data))
splitter_output = []
**kwargs: Additional keyword arguments.
for chunk in reader_output:
splitter_output.extend(self.splitter.invoke(chunk, key=chunk.hash_key))
Returns:
The result of invoking the processing chain.
"""
return super().invoke(file_path=file_path, max_workers=max_workers, split_length=window_length, window_length=window_length, **kwargs)
processed_chunk_keys = kwargs.get("processed_chunk_keys", set())
filtered_chunks = []
processed = 0
for chunk in splitter_output:
if chunk.hash_key not in processed_chunk_keys:
filtered_chunks.append(chunk)
else:
processed += 1
logger.debug(
f"Total chunks: {len(splitter_output)}. Checkpointed: {processed}, Pending: {len(filtered_chunks)}."
)
result = []
with ThreadPoolExecutor(max_workers) as executor:
futures = [executor.submit(run_extract, chunk) for chunk in filtered_chunks]
from tqdm import tqdm
for inner_future in tqdm(
as_completed(futures),
total=len(futures),
desc="KAG Extraction From Chunk",
position=1,
leave=False,
):
ret = inner_future.result()
result.append(ret)
return result
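The per-chunk flow in the new invoke can be pictured with a standalone sketch: each stage receives a list and returns a list, and optional stages set to None are skipped (the stages below are dummy callables, whereas the real components are driven via node.invoke):

def execute_node(node, node_input):
    if not isinstance(node_input, list):
        node_input = [node_input]
    node_output = []
    for item in node_input:
        node_output.extend(node(item))
    return node_output

def run_chunk_pipeline(chunk, stages):
    flow_data = [chunk]
    for stage in stages:
        if stage is None:
            continue  # e.g. the post_processor may be absent
        flow_data = execute_node(stage, flow_data)
    return flow_data

extractor = lambda c: [c + " -> extracted"]
vectorizer = lambda c: [c + " -> vectorized"]
writer = lambda c: [c + " -> written"]
print(run_chunk_pipeline("chunk-0", [extractor, vectorizer, None, writer]))
# ['chunk-0 -> extracted -> vectorized -> written']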

View File

@ -9,9 +9,10 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import hashlib
from enum import Enum
from typing import Dict, Any
from kag.common.utils import generate_hash_id
import json
class ChunkTypeEnum(str, Enum):
@ -26,29 +27,27 @@ class Chunk:
name: str,
content: str,
type: ChunkTypeEnum = ChunkTypeEnum.Text,
**kwargs
**kwargs,
):
self.id = id
self.name = name
self.type = type
self.content = content
self.kwargs = kwargs
for key, value in kwargs.items():
setattr(self, key, value)
@staticmethod
def generate_hash_id(value):
if isinstance(value, str):
value = value.encode("utf-8")
hasher = hashlib.sha256()
hasher.update(value)
return hasher.hexdigest()
@property
def hash_key(self):
return generate_hash_id(f"{self.id}{self.name}{self.content}")
def __str__(self):
tmp = {
"id": self.id,
"name": self.name,
"content": self.content
if len(self.content) <= 64
else self.content[:64] + " ...",
"content": (
self.content if len(self.content) <= 64 else self.content[:64] + " ..."
),
}
return f"<Chunk>: {tmp}"
@ -59,7 +58,9 @@ class Chunk:
"id": self.id,
"name": self.name,
"content": self.content,
"type": self.type.value if isinstance(self.type, ChunkTypeEnum) else self.type,
"type": (
self.type.value if isinstance(self.type, ChunkTypeEnum) else self.type
),
"properties": self.kwargs,
}
@ -72,3 +73,10 @@ class Chunk:
type=input_.get("type"),
**input_.get("properties", {}),
)
def dump_chunks(chunks, **kwargs):
if kwargs.get("output_path"):
with open(kwargs.get("output_path"), "w") as f:
for chunk in chunks:
f.write(json.dumps(chunk.to_dict(), ensure_ascii=False) + "\n")
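A short usage sketch of the Chunk model and the new module-level dump_chunks helper; imports follow the paths used elsewhere in this diff:

from kag.builder.model.chunk import Chunk, ChunkTypeEnum, dump_chunks
from kag.common.utils import generate_hash_id

chunk = Chunk(
    id=generate_hash_id("demo.txt#0"),
    name="demo.txt#0",
    content="KAG builds knowledge graphs from documents.",
    type=ChunkTypeEnum.Text,
    source="demo.txt",  # extra kwargs become attributes and are exported under "properties"
)
print(chunk.hash_key)  # stable hash over id + name + content
dump_chunks([chunk], output_path="chunks.jsonl")  # one JSON object per line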

View File

@ -23,145 +23,165 @@ class SPGRecord:
"""Data structure in operator, used to store entity information."""
def __init__(self, spg_type_name: SPGTypeName):
"""
Initializes a new instance of the SPGRecord class.
Args:
spg_type_name (SPGTypeName): The type name of the SPG entity.
"""
self._spg_type_name = spg_type_name
self._properties = {}
self._relations = {}
@property
def id(self) -> str:
"""
Gets the ID of the SPGRecord.
Returns:
str: The ID of the SPGRecord.
"""
return self.get_property("id", "")
@property
def name(self) -> str:
"""
Gets the name of the SPGRecord.
Returns:
str: The name of the SPGRecord.
"""
return self.get_property("name", self.id)
@property
def spg_type_name(self) -> SPGTypeName:
"""Gets the spg_type_name of this SPGRecord. # noqa: E501
"""
Gets the SPG type name of this SPGRecord.
:return: The spg_type_name of this SPGRecord. # noqa: E501
:rtype: str
Returns:
SPGTypeName: The SPG type name of this SPGRecord.
"""
return self._spg_type_name
@spg_type_name.setter
def spg_type_name(self, spg_type_name: SPGTypeName):
"""Sets the spg_type_name of this SPGRecord.
"""
Sets the SPG type name of this SPGRecord.
:param spg_type_name: The spg_type_name of this SPGRecord. # noqa: E501
:type: str
Args:
spg_type_name (SPGTypeName): The SPG type name of this SPGRecord.
"""
self._spg_type_name = spg_type_name
@property
def properties(self) -> Dict[PropertyName, str]:
"""Gets the properties of this SPGRecord. # noqa: E501
"""
Gets the properties of this SPGRecord.
:return: The properties of this SPGRecord. # noqa: E501
:rtype: dict
Returns:
Dict[PropertyName, str]: The properties of this SPGRecord.
"""
return self._properties
@properties.setter
def properties(self, properties: Dict[PropertyName, str]):
"""Sets the properties of this SPGRecord.
"""
Sets the properties of this SPGRecord.
:param properties: The properties of this SPGRecord. # noqa: E501
:type: dict
Args:
properties (Dict[PropertyName, str]): The properties of this SPGRecord.
"""
self._properties = properties
@property
def relations(self) -> Dict[str, str]:
"""Gets the relations of this SPGRecord. # noqa: E501
"""
Gets the relations of this SPGRecord.
:return: The relations of this SPGRecord. # noqa: E501
:rtype: dict
Returns:
Dict[str, str]: The relations of this SPGRecord.
"""
return self._relations
@relations.setter
def relations(self, relations: Dict[str, str]):
"""Sets the properties of this SPGRecord.
"""
Sets the relations of this SPGRecord.
:param relations: The relations of this SPGRecord. # noqa: E501
:type: dict
Args:
relations (Dict[str, str]): The relations of this SPGRecord.
"""
self._relations = relations
def get_property(
self, property_name: PropertyName, default_value: str = None
) -> str:
"""Gets a property of this SPGRecord by name. # noqa: E501
"""
Gets a property of this SPGRecord by name.
Args:
property_name (PropertyName): The property name.
default_value (str, optional): If the property value is None, the default_value will be returned. Defaults to None.
:param property_name: The property name. # noqa: E501
:param default_value: If property value is None, the default_value will be return. # noqa: E501
:return: A property value. # noqa: E501
:rtype: str
Returns:
str: The property value.
"""
return self.properties.get(property_name, default_value)
def upsert_property(self, property_name: PropertyName, value: str):
"""Upsert a property of this SPGRecord. # noqa: E501
"""
Upserts a property of this SPGRecord.
:param property_name: The updated property name. # noqa: E501
:param value: The updated property value. # noqa: E501
:type: str
Args:
property_name (PropertyName): The updated property name.
value (str): The updated property value.
"""
self.properties[property_name] = value
return self
def append_property(self, property_name: PropertyName, value: str):
"""Append a property of this SPGRecord. # noqa: E501
"""
Appends a property of this SPGRecord.
:param property_name: The updated property name. # noqa: E501
:param value: The updated property value. # noqa: E501
:type: str
Args:
property_name (PropertyName): The updated property name.
value (str): The updated property value.
"""
property_value = self.get_property(property_name)
if property_value:
property_value_list = property_value.split(',')
property_value_list = property_value.split(",")
if value not in property_value_list:
self.properties[property_name] = property_value + ',' + value
self.properties[property_name] = property_value + "," + value
else:
self.properties[property_name] = value
return self
def upsert_properties(self, properties: Dict[PropertyName, str]):
"""Upsert properties of this SPGRecord. # noqa: E501
"""
Upserts properties of this SPGRecord.
:param properties: The updated properties. # noqa: E501
:type: dict
Args:
properties (Dict[PropertyName, str]): The updated properties.
"""
self.properties.update(properties)
return self
def remove_property(self, property_name: PropertyName):
"""Removes a property of this SPGRecord. # noqa: E501
"""
Removes a property of this SPGRecord.
:param property_name: The property name. # noqa: E501
:type: str
Args:
property_name (PropertyName): The property name.
"""
self.properties.pop(property_name)
return self
def remove_properties(self, property_names: List[PropertyName]):
"""Removes properties by given names. # noqa: E501
"""
Removes properties by given names.
:param property_names: A list of property names. # noqa: E501
:type: list
Args:
property_names (List[PropertyName]): A list of property names.
"""
for property_name in property_names:
self.properties.pop(property_name)
@ -173,37 +193,39 @@ class SPGRecord:
object_type_name: SPGTypeName,
default_value: str = None,
) -> str:
"""Gets a relation of this SPGRecord by name. # noqa: E501
"""
Gets a relation of this SPGRecord by name.
Args:
relation_name (RelationName): The relation name.
object_type_name (SPGTypeName): The object SPG type name.
default_value (str, optional): If the relation value is None, the default_value will be returned. Defaults to None.
:param relation_name: The relation name. # noqa: E501
:param object_type_name: The object SPG type name. # noqa: E501
:param default_value: If property value is None, the default_value will be return. # noqa: E501
:return: A relation value. # noqa: E501
:rtype: str
Returns:
str: The relation value.
"""
return self.relations.get(relation_name + "#" + object_type_name, default_value)
def upsert_relation(
self, relation_name: RelationName, object_type_name: SPGTypeName, value: str
):
"""Upsert a relation of this SPGRecord. # noqa: E501
"""
Upserts a relation of this SPGRecord.
:param relation_name: The updated relation name. # noqa: E501
:param object_type_name: The object SPG type name. # noqa: E501
:param value: The updated relation value. # noqa: E501
:type: str
Args:
relation_name (RelationName): The updated relation name.
object_type_name (SPGTypeName): The object SPG type name.
value (str): The updated relation value.
"""
self.relations[relation_name + "#" + object_type_name] = value
return self
def upsert_relations(self, relations: Dict[Tuple[RelationName, SPGTypeName], str]):
"""Upsert relations of this SPGRecord. # noqa: E501
"""
Upserts relations of this SPGRecord.
:param relations: The updated relations. # noqa: E501
:type: dict
Args:
relations (Dict[Tuple[RelationName, SPGTypeName], str]): The updated relations.
"""
for (relation_name, object_type_name), value in relations.items():
self.relations[relation_name + "#" + object_type_name] = value
@ -212,33 +234,43 @@ class SPGRecord:
def remove_relation(
self, relation_name: RelationName, object_type_name: SPGTypeName
):
"""Removes a relation of this SPGRecord. # noqa: E501
"""
Removes a relation of this SPGRecord.
:param relation_name: The relation name. # noqa: E501
:param object_type_name: The object SPG type name. # noqa: E501
:type: str
Args:
relation_name (RelationName): The relation name.
object_type_name (SPGTypeName): The object SPG type name.
"""
self.relations.pop(relation_name + "#" + object_type_name)
return self
def remove_relations(self, relation_names: List[Tuple[RelationName, SPGTypeName]]):
"""Removes relations by given names. # noqa: E501
:param relation_names: A list of relation names. # noqa: E501
:type: list
"""
for (relation_name, object_type_name) in relation_names:
Removes relations by given names.
Args:
relation_names (List[Tuple[RelationName, SPGTypeName]]): A list of relation names.
"""
for relation_name, object_type_name in relation_names:
self.relations.pop(relation_name + "#" + object_type_name)
return self
def to_str(self):
"""Returns the string representation of the model"""
"""
Returns the string representation of the model.
Returns:
str: The string representation of the model.
"""
return pprint.pformat(self.__dict__())
def to_dict(self):
"""Returns the model properties as a dict"""
"""
Returns the model properties as a dict.
Returns:
dict: The model properties as a dict.
"""
return {
"spgTypeName": self.spg_type_name,
@ -249,7 +281,12 @@ class SPGRecord:
}
def __dict__(self):
"""Returns this SPGRecord as a dict"""
"""
Returns this SPGRecord as a dict.
Returns:
dict: This SPGRecord as a dict.
"""
return {
"spgTypeName": self.spg_type_name,
"properties": self.properties,
@ -258,7 +295,15 @@ class SPGRecord:
@classmethod
def from_dict(cls, input: Dict[str, Any]):
"""Returns the model from a dict"""
"""
Returns the model from a dict.
Args:
input (Dict[str, Any]): The input dictionary.
Returns:
SPGRecord: The model from the input dictionary.
"""
spg_type_name = input.get("spgTypeName")
_cls = cls(spg_type_name)
properties = input.get("properties")
@ -272,5 +317,10 @@ class SPGRecord:
return _cls
def __repr__(self):
"""For `print` and `pprint`"""
"""
For `print` and `pprint`.
Returns:
str: The string representation of the model.
"""
return pprint.pformat(self.__dict__())
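A short usage sketch of the SPGRecord API documented above (the type and property names are illustrative; the import path is the one used elsewhere in this diff):

from kag.builder.model.spg_record import SPGRecord

record = SPGRecord("Medical.Disease")
record.upsert_properties({"id": "d001", "name": "hypertension"})
record.append_property("symptom", "headache")
record.append_property("symptom", "dizziness")  # appended as a comma-separated value
print(record.get_property("symptom"))  # headache,dizziness
print(record.to_dict())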

View File

@ -10,10 +10,11 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import pprint
import copy
from typing import Dict, List, Any
from knext.schema.client import BASIC_TYPES
from kag.builder.model.spg_record import SPGRecord
from knext.schema.client import BASIC_TYPES
from knext.schema.model.base import BaseSpgType
@ -41,14 +42,14 @@ class Node(object):
@staticmethod
def unique_key(spg_record):
return spg_record.spg_type_name + '_' + spg_record.get_property("name", "")
return spg_record.spg_type_name + "_" + spg_record.get_property("name", "")
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"label": self.label,
"properties": self.properties,
"properties": copy.deepcopy(self.properties),
}
@classmethod
@ -57,11 +58,15 @@ class Node(object):
_id=input["id"],
name=input["name"],
label=input["label"],
properties=input["properties"],
properties=input.get("properties", {}),
)
def __eq__(self, other):
return self.name == other.name and self.label == other.label and self.properties == other.properties
return (
self.name == other.name
and self.label == other.label
and self.properties == other.properties
)
class Edge(object):
@ -74,7 +79,12 @@ class Edge(object):
properties: Dict[str, str]
def __init__(
self, _id: str, from_node: Node, to_node: Node, label: str, properties: Dict[str, str]
self,
_id: str,
from_node: Node,
to_node: Node,
label: str,
properties: Dict[str, str],
):
self.from_id = from_node.id
self.from_type = from_node.label
@ -88,12 +98,19 @@ class Edge(object):
@classmethod
def from_spg_record(
cls, s_idx, subject_record: SPGRecord, o_idx, object_record: SPGRecord, label: str
cls,
s_idx,
subject_record: SPGRecord,
o_idx,
object_record: SPGRecord,
label: str,
):
from_node = Node.from_spg_record(s_idx, subject_record)
to_node = Node.from_spg_record(o_idx, object_record)
return cls(_id="", from_node=from_node, to_node=to_node, label=label, properties={})
return cls(
_id="", from_node=from_node, to_node=to_node, label=label, properties={}
)
def to_dict(self):
return {
@ -103,21 +120,35 @@ class Edge(object):
"fromType": self.from_type,
"toType": self.to_type,
"label": self.label,
"properties": self.properties,
"properties": copy.deepcopy(self.properties),
}
@classmethod
def from_dict(cls, input: Dict):
return cls(
_id=input["id"],
from_node=Node(_id=input["from"], name=input["from"],label=input["fromType"], properties={}),
to_node=Node(_id=input["to"], name=input["to"], label=input["toType"], properties={}),
from_node=Node(
_id=input["from"],
name=input["from"],
label=input["fromType"],
properties={},
),
to_node=Node(
_id=input["to"], name=input["to"], label=input["toType"], properties={}
),
label=input["label"],
properties=input["properties"],
properties=input.get("properties", {}),
)
def __eq__(self, other):
return self.from_id == other.from_id and self.to_id == other.to_id and self.label == other.label and self.properties == other.properties and self.from_type == other.from_type and self.to_type == other.to_type
return (
self.from_id == other.from_id
and self.to_id == other.to_id
and self.label == other.label
and self.properties == other.properties
and self.from_type == other.from_type
and self.to_type == other.to_type
)
class SubGraph(object):
@ -135,12 +166,18 @@ class SubGraph(object):
self.nodes.append(Node(_id=id, name=name, label=label, properties=properties))
return self
def add_edge(self, s_id: str, s_label: str, p: str, o_id: str, o_label: str, properties=None):
def add_edge(
self, s_id: str, s_label: str, p: str, o_id: str, o_label: str, properties=None
):
if not properties:
properties = dict()
s_node = Node(_id=s_id, name=s_id, label=s_label, properties={})
o_node = Node(_id=o_id, name=o_id, label=o_label, properties={})
self.edges.append(Edge(_id="", from_node=s_node, to_node=o_node, label=p, properties=properties))
self.edges.append(
Edge(
_id="", from_node=s_node, to_node=o_node, label=p, properties=properties
)
)
return self
def to_dict(self):
@ -152,7 +189,7 @@ class SubGraph(object):
def __repr__(self):
return pprint.pformat(self.to_dict())
def merge(self, sub_graph: 'SubGraph'):
def merge(self, sub_graph: "SubGraph"):
self.nodes.extend(sub_graph.nodes)
self.edges.extend(sub_graph.edges)
@ -164,21 +201,30 @@ class SubGraph(object):
for record in spg_records:
s_id = record.id
s_name = record.name
s_label = record.spg_type_name.split('.')[-1]
s_label = record.spg_type_name.split(".")[-1]
properties = record.properties
spg_type = spg_types.get(record.spg_type_name)
for prop_name, prop_value in record.properties.items():
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name.split('.')[-1]
o_label = prop.object_type_name.split(".")[-1]
if o_label not in BASIC_TYPES:
prop_value_list = prop_value.split(',')
prop_value_list = prop_value.split(",")
for o_id in prop_value_list:
sub_graph.add_edge(s_id=s_id, s_label=s_label, p=prop_name, o_id=o_id, o_label=o_label)
sub_graph.add_edge(
s_id=s_id,
s_label=s_label,
p=prop_name,
o_id=o_id,
o_label=o_label,
)
properties.pop(prop_name)
sub_graph.add_node(id=s_id, name=s_name, label=s_label, properties=properties)
sub_graph.add_node(
id=s_id, name=s_name, label=s_label, properties=properties
)
return sub_graph
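And a matching sketch for SubGraph; add_node and add_edge are shown above, while the constructor arguments (initial node and edge lists) are an assumption:

from kag.builder.model.sub_graph import SubGraph

graph = SubGraph(nodes=[], edges=[])  # constructor arguments assumed
graph.add_node(id="p1", name="Alice", label="Person", properties={"age": "30"})
graph.add_node(id="c1", name="Acme", label="Company", properties={})
graph.add_edge(s_id="p1", s_label="Person", p="worksFor", o_id="c1", o_label="Company")
print(graph.to_dict())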

View File

@ -9,4 +9,3 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.prompt.default.ner import OpenIENERPrompt as DefaultOpenIENERPrompt
from kag.builder.prompt.default.std import (
OpenIEEntitystandardizationdPrompt as DefaultOpenIEEntitystandardizationdPrompt,
)
from kag.builder.prompt.default.triple import (
OpenIETriplePrompt as DefaultOpenIETriplePrompt,
)
from kag.builder.prompt.medical.ner import OpenIENERPrompt as MedicalOpenIENERPrompt
from kag.builder.prompt.medical.std import (
OpenIEEntitystandardizationdPrompt as MedicalOpenIEEntitystandardizationdPrompt,
)
from kag.builder.prompt.medical.triple import (
OpenIETriplePrompt as MedicalOpenIETriplePrompt,
)
from kag.builder.prompt.analyze_table_prompt import AnalyzeTablePrompt
from kag.builder.prompt.spg_prompt import SPGPrompt, SPGEntityPrompt, SPGEventPrompt
from kag.builder.prompt.semantic_seg_prompt import SemanticSegPrompt
from kag.builder.prompt.outline_prompt import OutlinePrompt
__all__ = [
"DefaultOpenIENERPrompt",
"DefaultOpenIEEntitystandardizationdPrompt",
"DefaultOpenIETriplePrompt",
"MedicalOpenIENERPrompt",
"MedicalOpenIEEntitystandardizationdPrompt",
"MedicalOpenIETriplePrompt",
"AnalyzeTablePrompt",
"OutlinePrompt",
"SemanticSegPrompt",
"SPGPrompt",
"SPGEntityPrompt",
"SPGEventPrompt",
]

View File

@ -13,34 +13,24 @@
import json
import logging
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
logger = logging.getLogger(__name__)
class AnalyzeTablePrompt(PromptOp):
@PromptABC.register("analyze_table")
class AnalyzeTablePrompt(PromptABC):
template_zh: str = """你是一个分析表格的专家, 从table中提取信息并分析最后返回表格有效信息"""
template_en: str = """You are an expert in knowledge graph extraction. Based on the schema defined by the constraint, extract all entities and their attributes from the input. Return NAN for attributes not explicitly mentioned in the input. Output the results in standard JSON format, as a list."""
def __init__(
self,
language: str = "zh",
):
super().__init__(
language=language,
)
def build_prompt(self, variables) -> str:
return json.dumps(
{
"instruction": self.template,
"table": variables.get("table",""),
"table": variables.get("table", ""),
},
ensure_ascii=False,
)
def parse_response(self, response: str, **kwargs):
return response
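For reference, build_prompt above simply packs the instruction template and the table into a JSON string; a standalone mirror of that step (with an English placeholder instruction):

import json

def build_prompt(variables, template="You are a table-analysis expert: extract the useful information from the table."):
    return json.dumps(
        {"instruction": template, "table": variables.get("table", "")},
        ensure_ascii=False,
    )

print(build_prompt({"table": "| name | age |\n| Alice | 30 |"}))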

View File

@ -12,66 +12,66 @@
import json
from string import Template
from typing import List, Optional
from kag.common.base.prompt_op import PromptOp
from typing import List
from kag.common.conf import KAG_PROJECT_CONF
from kag.interface import PromptABC
from knext.schema.client import SchemaClient
class OpenIENERPrompt(PromptOp):
@PromptABC.register("default_ner")
class OpenIENERPrompt(PromptABC):
template_en = """
{
"instruction": "You're a very effective entity extraction system. Please extract all the entities that are important for knowledge build and question, along with type, category and a brief description of the entity. The description of the entity is based on your OWN KNOWLEDGE AND UNDERSTANDING and does not need to be limited to the context. the entity's category belongs taxonomically to one of the items defined by schema, please also output the category. Note: Type refers to a specific, well-defined classification, such as Professor, Actor, while category is a broader group or class that may contain more than one type, such as Person, Works. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.You can refer to the example for extraction.",
"schema": $schema,
"example": [
{
"input": "The Rezort\nThe Rezort is a 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger.\n It stars Dougray Scott, Jessica De Gouw and Martin McCann.\n After humanity wins a devastating war against zombies, the few remaining undead are kept on a secure island, where they are hunted for sport.\n When something goes wrong with the island's security, the guests must face the possibility of a new outbreak.",
"input": "The Rezort is a 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger. It stars Dougray Scott, Jessica De Gouw and Martin McCann. After humanity wins a devastating war against zombies, the few remaining undead are kept on a secure island, where they are hunted for sport. When something goes wrong with the island's security, the guests must face the possibility of a new outbreak.",
"output": [
{
"entity": "The Rezort",
"name": "The Rezort",
"type": "Movie",
"category": "Works",
"description": "A 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger."
},
{
"entity": "2015",
"name": "2015",
"type": "Year",
"category": "Date",
"description": "The year the movie 'The Rezort' was released."
},
{
"entity": "British",
"name": "British",
"type": "Nationality",
"category": "GeographicLocation",
"description": "Great Britain, the island that includes England, Scotland, and Wales."
},
{
"entity": "Steve Barker",
"name": "Steve Barker",
"type": "Director",
"category": "Person",
"description": "Steve Barker is an English film director and screenwriter."
},
{
"entity": "Paul Gerstenberger",
"name": "Paul Gerstenberger",
"type": "Writer",
"category": "Person",
"description": "Paul is a writer and producer, known for The Rezort (2015), Primeval (2007) and House of Anubis (2011)."
},
{
"entity": "Dougray Scott",
"name": "Dougray Scott",
"type": "Actor",
"category": "Person",
"description": "Stephen Dougray Scott (born 26 November 1965) is a Scottish actor."
},
{
"entity": "Jessica De Gouw",
"name": "Jessica De Gouw",
"type": "Actor",
"category": "Person",
"description": "Jessica Elise De Gouw (born 15 February 1988) is an Australian actress. "
},
{
"entity": "Martin McCann",
"name": "Martin McCann",
"type": "Actor",
"category": "Person",
"description": "Martin McCann is an actor from Northern Ireland. In 2020, he was listed as number 48 on The Irish Times list of Ireland's greatest film actors"
@ -89,52 +89,52 @@ class OpenIENERPrompt(PromptOp):
"schema": $schema,
"example": [
{
"input": "《Rezort》\n《Rezort》是一部 2015 年英国僵尸恐怖片,由史蒂夫·巴克执导,保罗·格斯滕伯格编剧。\n 该片由道格瑞·斯科特、杰西卡·德·古维和马丁·麦凯恩主演。\n 在人类赢得与僵尸的毁灭性战争后,剩下的少数不死生物被关在一个安全的岛屿上,在那里他们被猎杀作为消遣。\n 当岛上的安全出现问题时,客人们必须面对新一轮疫情爆发的可能性。",
"input": "《Rezort》是一部 2015年英国僵尸恐怖片由史蒂夫·巴克执导保罗·格斯滕伯格编剧。该片由道格瑞·斯科特、杰西卡·德·古维和马丁·麦凯恩主演。在人类赢得与僵尸的毁灭性战争后剩下的少数不死生物被关在一个安全的岛屿上在那里他们被猎杀作为消遣。当岛上的安全出现问题时客人们必须面对新一轮疫情爆发的可能性。",
"output": [
{
"entity": "The Rezort",
"name": "The Rezort",
"type": "Movie",
"category": "Works",
"description": "一部 2015 年英国僵尸恐怖片,由史蒂夫·巴克执导,保罗·格斯滕伯格编剧。"
},
{
"entity": "2015",
"name": "2015",
"type": "Year",
"category": "Date",
"description": "电影《The Rezort》上映的年份。"
},
{
"entity": "英国",
"name": "英国",
"type": "Nationality",
"category": "GeographicLocation",
"description": "大不列颠,包括英格兰、苏格兰和威尔士的岛屿。"
},
{
"entity": "史蒂夫·巴克",
"name": "史蒂夫·巴克",
"type": "Director",
"category": "Person",
"description": "史蒂夫·巴克 是一名英国电影导演和剧作家"
},
{
"entity": "保罗·格斯滕伯格",
"name": "保罗·格斯滕伯格",
"type": "Writer",
"category": "Person",
"description": "保罗·格斯滕伯格 (Paul Gerstenberger) 是一名作家和制片人因《The Rezort》2015 年、《Primeval》2007 年和《House of Anubis》2011 年)而闻名。"
},
{
"entity": "道格雷·斯科特",
"name": "道格雷·斯科特",
"type": "Actor",
"category": "Person",
"description": "斯蒂芬·道格雷·斯科特 (Stephen Dougray Scott1965 年 11 月 26 日出生) 是一位苏格兰演员。"
},
{
"entity": "杰西卡·德·古维",
"name": "杰西卡·德·古维",
"type": "Actor",
"category": "Person",
"description": "杰西卡·伊莉斯·德·古维 (Jessica Elise De Gouw1988 年 2 月 15 日出生) 是一位澳大利亚女演员。"
},
{
"entity": "马丁·麦肯",
"name": "马丁·麦肯",
"type": "Actor",
"category": "Person",
"description": "马丁·麦肯是来自北爱尔兰的演员。2020 年,他在《爱尔兰时报》爱尔兰最伟大电影演员名单中排名第 48 位"
@ -146,12 +146,14 @@ class OpenIENERPrompt(PromptOp):
}
"""
def __init__(
self, language: Optional[str] = "en", **kwargs
):
def __init__(self, language: str = "", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
self.schema = SchemaClient(
project_id=KAG_PROJECT_CONF.project_id
).extract_types()
self.template = Template(self.template).safe_substitute(
schema=json.dumps(self.schema)
)
@property
def template_variables(self) -> List[str]:

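For context on the new initializer above: the schema is now serialized with json.dumps before being substituted into the $schema placeholder, so the rendered prompt stays valid JSON rather than a Python repr. A minimal stdlib-only sketch of that substitution step (the template string and schema list are made-up stand-ins for the class template and SchemaClient(...).extract_types()):

import json
from string import Template

# Hypothetical stand-ins; the real values come from the class template and the project schema.
template = '{"instruction": "...", "schema": $schema, "input": "$input"}'
schema = ["Person", "Works", "GeographicLocation", "Date"]

# safe_substitute fills $schema but leaves $input untouched, so it can be filled per chunk later.
rendered = Template(template).safe_substitute(schema=json.dumps(schema))
print(rendered)
# {"instruction": "...", "schema": ["Person", "Works", "GeographicLocation", "Date"], "input": "$input"}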
View File

@ -11,65 +11,66 @@
# or implied.
import json
from typing import Optional, List
from typing import List
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
class OpenIEEntitystandardizationdPrompt(PromptOp):
@PromptABC.register("default_std")
class OpenIEEntitystandardizationdPrompt(PromptABC):
template_en = """
{
"instruction": "The `input` field contains a user provided context. The `named_entities` field contains extracted named entities from the context, which may be unclear abbreviations, aliases, or slang. To eliminate ambiguity, please attempt to provide the official names of these entities based on the context and your own knowledge. Note that entities with the same meaning can only have ONE official name. Please respond in the format of a single JSONArray string without any explanation, as shown in the `output` field of the provided example.",
"example": {
"input": "American History\nWhen did the political party that favored harsh punishment of southern states after the Civil War, gain control of the House? Republicans regained control of the chamber they had lost in the 2006 midterm elections.",
"input": "American History.When did the political party that favored harsh punishment of southern states after the Civil War, gain control of the House? Republicans regained control of the chamber they had lost in the 2006 midterm elections.",
"named_entities": [
{"entity": "American", "category": "GeographicLocation"},
{"entity": "political party", "category": "Organization"},
{"entity": "southern states", "category": "GeographicLocation"},
{"entity": "Civil War", "category": "Keyword"},
{"entity": "House", "category": "Organization"},
{"entity": "Republicans", "category": "Organization"},
{"entity": "chamber", "category": "Organization"},
{"entity": "2006 midterm elections", "category": "Date"}
{"name": "American", "category": "GeographicLocation"},
{"name": "political party", "category": "Organization"},
{"name": "southern states", "category": "GeographicLocation"},
{"name": "Civil War", "category": "Keyword"},
{"name": "House", "category": "Organization"},
{"name": "Republicans", "category": "Organization"},
{"name": "chamber", "category": "Organization"},
{"name": "2006 midterm elections", "category": "Date"}
],
"output": [
{
"entity": "American",
"name": "American",
"category": "GeographicLocation",
"official_name": "United States of America"
},
{
"entity": "political party",
"name": "political party",
"category": "Organization",
"official_name": "Radical Republicans"
},
{
"entity": "southern states",
"name": "southern states",
"category": "GeographicLocation",
"official_name": "Confederacy"
},
{
"entity": "Civil War",
"name": "Civil War",
"category": "Keyword",
"official_name": "American Civil War"
},
{
"entity": "House",
"name": "House",
"category": "Organization",
"official_name": "United States House of Representatives"
},
{
"entity": "Republicans",
"name": "Republicans",
"category": "Organization",
"official_name": "Republican Party"
},
{
"entity": "chamber",
"name": "chamber",
"category": "Organization",
"official_name": "United States House of Representatives"
},
{
"entity": "midterm elections",
"name": "midterm elections",
"category": "Date",
"official_name": "United States midterm elections"
}
@ -84,26 +85,26 @@ class OpenIEEntitystandardizationdPrompt(PromptOp):
{
"instruction": "input字段包含用户提供的上下文。命名实体字段包含从上下文中提取的命名实体这些可能是含义不明的缩写、别名或俚语。为了消除歧义请尝试根据上下文和您自己的知识提供这些实体的官方名称。请注意具有相同含义的实体只能有一个官方名称。请按照提供的示例中的输出字段格式以单个JSONArray字符串形式回复无需任何解释。",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"named_entities": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
{"name": "烦躁不安", "category": "Symptom"},
{"name": "语妄", "category": "Symptom"},
{"name": "失眠", "category": "Symptom"},
{"name": "镇静药", "category": "Medicine"},
{"name": "肺外感染", "category": "Disease"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"name": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output": [
{"entity": "烦躁不安", "category": "Symptom", "official_name": "焦虑不安"},
{"entity": "语妄", "category": "Symptom", "official_name": "谵妄"},
{"entity": "失眠", "category": "Symptom", "official_name": "失眠症"},
{"entity": "镇静药", "category": "Medicine", "official_name": "镇静剂"},
{"entity": "肺外感染", "category": "Disease", "official_name": "肺外感染"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment", "official_name": "胸腔引流管"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment", "official_name": "负压吸引装置"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation", "official_name": "闭式负压引流"}
{"name": "烦躁不安", "category": "Symptom", "official_name": "焦虑不安"},
{"name": "语妄", "category": "Symptom", "official_name": "谵妄"},
{"name": "失眠", "category": "Symptom", "official_name": "失眠症"},
{"name": "镇静药", "category": "Medicine", "official_name": "镇静剂"},
{"name": "肺外感染", "category": "Disease", "official_name": "肺外感染"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment", "official_name": "胸腔引流管"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment", "official_name": "负压吸引装置"},
{"name": "闭式负压吸引", "category": "SurgicalOperation", "official_name": "闭式负压引流"}
]
},
"input": $input,
@ -111,15 +112,11 @@ class OpenIEEntitystandardizationdPrompt(PromptOp):
}
"""
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input", "named_entities"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
@ -134,10 +131,10 @@ class OpenIEEntitystandardizationdPrompt(PromptOp):
entities = kwargs.get("named_entities", [])
for entity in standardized_entity:
merged.append(entity)
entities_with_offical_name.add(entity["entity"])
entities_with_offical_name.add(entity["name"])
# in case llm ignores some entities
for entity in entities:
if entity["entity"] not in entities_with_offical_name:
entity["official_name"] = entity["entity"]
if entity["name"] not in entities_with_offical_name:
entity["official_name"] = entity["name"]
merged.append(entity)
return merged

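A small worked illustration of the merge logic in parse_response after the entity-to-name key rename (trimmed inputs, standalone code rather than the class method): any entity the LLM skipped is backfilled with its own name as its official_name.

standardized = [
    {"name": "American", "category": "GeographicLocation", "official_name": "United States of America"},
    {"name": "House", "category": "Organization", "official_name": "United States House of Representatives"},
]
named_entities = [
    {"name": "American", "category": "GeographicLocation"},
    {"name": "House", "category": "Organization"},
    {"name": "chamber", "category": "Organization"},
]

merged, seen = [], set()
for entity in standardized:
    merged.append(entity)
    seen.add(entity["name"])
for entity in named_entities:
    if entity["name"] not in seen:  # the LLM ignored this entity
        entity["official_name"] = entity["name"]
        merged.append(entity)
# merged holds three items; "chamber" falls back to official_name == "chamber"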
View File

@ -11,66 +11,67 @@
# or implied.
import json
from typing import Optional, List
from typing import List
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
class OpenIETriplePrompt(PromptOp):
@PromptABC.register("default_triple")
class OpenIETriplePrompt(PromptABC):
template_en = """
{
"instruction": "You are an expert specializing in carrying out open information extraction (OpenIE). Please extract any possible relations (including subject, predicate, object) from the given text, and list them following the json format {\"triples\": [[\"subject\", \"predicate\", \"object\"]]}\n. If there are none, do not list them.\n.\n\nPay attention to the following requirements:\n- Each triple should contain at least one, but preferably two, of the named entities in the entity_list.\n- Clearly resolve pronouns to their specific names to maintain clarity.",
"instruction": "You are an expert specializing in carrying out open information extraction (OpenIE). Please extract any possible relations (including subject, predicate, object) from the given text, and list them following the json format {\"triples\": [[\"subject\", \"predicate\", \"object\"]]}. If there are none, do not list them..Pay attention to the following requirements:- Each triple should contain at least one, but preferably two, of the named entities in the entity_list.- Clearly resolve pronouns to their specific names to maintain clarity.",
"entity_list": $entity_list,
"input": "$input",
"example": {
"input": "The Rezort\nThe Rezort is a 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger.\n It stars Dougray Scott, Jessica De Gouw and Martin McCann.\n After humanity wins a devastating war against zombies, the few remaining undead are kept on a secure island, where they are hunted for sport.\n When something goes wrong with the island's security, the guests must face the possibility of a new outbreak.",
"input": "The RezortThe Rezort is a 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger. It stars Dougray Scott, Jessica De Gouw and Martin McCann. After humanity wins a devastating war against zombies, the few remaining undead are kept on a secure island, where they are hunted for sport. When something goes wrong with the island's security, the guests must face the possibility of a new outbreak.",
"entity_list": [
{
"entity": "The Rezort",
"name": "The Rezort",
"category": "Works"
},
{
"entity": "2015",
"name": "2015",
"category": "Others"
},
{
"entity": "British",
"name": "British",
"category": "GeographicLocation"
},
{
"entity": "Steve Barker",
"name": "Steve Barker",
"category": "Person"
},
{
"entity": "Paul Gerstenberger",
"name": "Paul Gerstenberger",
"category": "Person"
},
{
"entity": "Dougray Scott",
"name": "Dougray Scott",
"category": "Person"
},
{
"entity": "Jessica De Gouw",
"name": "Jessica De Gouw",
"category": "Person"
},
{
"entity": "Martin McCann",
"name": "Martin McCann",
"category": "Person"
},
{
"entity": "zombies",
"name": "zombies",
"category": "Creature"
},
{
"entity": "zombie horror film",
"name": "zombie horror film",
"category": "Concept"
},
{
"entity": "humanity",
"name": "humanity",
"category": "Concept"
},
{
"entity": "secure island",
"name": "secure island",
"category": "GeographicLocation"
}
],
@ -151,16 +152,16 @@ class OpenIETriplePrompt(PromptOp):
"entity_list": $entity_list,
"input": "$input",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"entity_list": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
{"name": "烦躁不安", "category": "Symptom"},
{"name": "语妄", "category": "Symptom"},
{"name": "失眠", "category": "Symptom"},
{"name": "镇静药", "category": "Medicine"},
{"name": "肺外感染", "category": "Disease"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"name": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output":[
["烦躁不安", "酌用", "镇静药"],
@ -178,9 +179,6 @@ class OpenIETriplePrompt(PromptOp):
}
"""
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["entity_list", "input"]

View File

@ -12,14 +12,14 @@
import json
from string import Template
from typing import List, Optional
from kag.common.base.prompt_op import PromptOp
from typing import List
from kag.common.conf import KAG_PROJECT_CONF
from kag.interface import PromptABC
from knext.schema.client import SchemaClient
class OpenIENERPrompt(PromptOp):
@PromptABC.register("medical_ner")
class OpenIENERPrompt(PromptABC):
template_zh = """
{
"instruction": "你是命名实体识别的专家。请从输入中提取与模式定义匹配的实体。如果不存在该类型的实体请返回一个空列表。请以JSON字符串格式回应。你可以参照example进行抽取。",
@ -28,14 +28,14 @@ class OpenIENERPrompt(PromptOp):
{
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染。\n治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"output": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
{"name": "烦躁不安", "category": "Symptom"},
{"name": "语妄", "category": "Symptom"},
{"name": "失眠", "category": "Symptom"},
{"name": "镇静药", "category": "Medicine"},
{"name": "肺外感染", "category": "Disease"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"name": "闭式负压吸引", "category": "SurgicalOperation"}
]
}
],
@ -45,11 +45,11 @@ class OpenIENERPrompt(PromptOp):
template_en = template_zh
def __init__(
self, language: Optional[str] = "en", **kwargs
):
def __init__(self, language: str = "", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.schema = SchemaClient(
project_id=KAG_PROJECT_CONF.project_id
).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
@property

View File

@ -11,37 +11,37 @@
# or implied.
import json
from typing import Optional, List
from typing import List
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
class OpenIEEntitystandardizationdPrompt(PromptOp):
@PromptABC.register("medical_std")
class OpenIEEntitystandardizationdPrompt(PromptABC):
template_zh = """
{
"instruction": "input字段包含用户提供的上下文。命名实体字段包含从上下文中提取的命名实体这些可能是含义不明的缩写、别名或俚语。为了消除歧义请尝试根据上下文和您自己的知识提供这些实体的官方名称。请注意具有相同含义的实体只能有一个官方名称。请按照提供的示例中的输出字段格式以单个JSONArray字符串形式回复无需任何解释。",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"named_entities": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
{"name": "烦躁不安", "category": "Symptom"},
{"name": "语妄", "category": "Symptom"},
{"name": "失眠", "category": "Symptom"},
{"name": "镇静药", "category": "Medicine"},
{"name": "肺外感染", "category": "Disease"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"name": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output": [
{"entity": "烦躁不安", "category": "Symptom", "official_name": "焦虑不安"},
{"entity": "语妄", "category": "Symptom", "official_name": "谵妄"},
{"entity": "失眠", "category": "Symptom", "official_name": "失眠症"},
{"entity": "镇静药", "category": "Medicine", "official_name": "镇静剂"},
{"entity": "肺外感染", "category": "Disease", "official_name": "肺外感染"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment", "official_name": "胸腔引流管"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment", "official_name": "负压吸引装置"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation", "official_name": "闭式负压引流"}
{"name": "烦躁不安", "category": "Symptom", "official_name": "焦虑不安"},
{"name": "语妄", "category": "Symptom", "official_name": "谵妄"},
{"name": "失眠", "category": "Symptom", "official_name": "失眠症"},
{"name": "镇静药", "category": "Medicine", "official_name": "镇静剂"},
{"name": "肺外感染", "category": "Disease", "official_name": "肺外感染"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment", "official_name": "胸腔引流管"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment", "official_name": "负压吸引装置"},
{"name": "闭式负压吸引", "category": "SurgicalOperation", "official_name": "闭式负压引流"}
]
},
"input": $input,
@ -51,15 +51,11 @@ class OpenIEEntitystandardizationdPrompt(PromptOp):
template_en = template_zh
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input", "named_entities"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
@ -74,10 +70,10 @@ class OpenIEEntitystandardizationdPrompt(PromptOp):
entities = kwargs.get("named_entities", [])
for entity in standardized_entity:
merged.append(entity)
entities_with_offical_name.add(entity["entity"])
entities_with_offical_name.add(entity["name"])
# in case llm ignores some entities
for entity in entities:
if entity["entity"] not in entities_with_offical_name:
entity["official_name"] = entity["entity"]
if entity["name"] not in entities_with_offical_name:
entity["official_name"] = entity["name"]
merged.append(entity)
return merged

View File

@ -11,13 +11,13 @@
# or implied.
import json
from typing import Optional, List, Dict, Any
from typing import List
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
class OpenIETriplePrompt(PromptOp):
@PromptABC.register("medical_triple")
class OpenIETriplePrompt(PromptABC):
template_zh = """
{
"instruction": "您是一位专门从事开放信息提取OpenIE的专家。请从input字段的文本中提取任何可能的关系包括主语、谓语、宾语并按照JSON格式列出它们须遵循example字段的示例格式。请注意以下要求1. 每个三元组应至少包含entity_list实体列表中的一个但最好是两个命名实体。2. 明确地将代词解析为特定名称,以保持清晰度。",
@ -26,14 +26,14 @@ class OpenIETriplePrompt(PromptOp):
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"entity_list": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
{"name": "烦躁不安", "category": "Symptom"},
{"name": "语妄", "category": "Symptom"},
{"name": "失眠", "category": "Symptom"},
{"name": "镇静药", "category": "Medicine"},
{"name": "肺外感染", "category": "Disease"},
{"name": "胸腔压力调节管", "category": "MedicalEquipment"},
{"name": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"name": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output":[
["烦躁不安", "酌用", "镇静药"],
@ -53,9 +53,6 @@ class OpenIETriplePrompt(PromptOp):
template_en = template_zh
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["entity_list", "input"]

View File

@ -1,518 +0,0 @@
#
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import re
from abc import ABC
from typing import List, Dict, Any
from collections import defaultdict
from knext.schema.model.schema_helper import SPGTypeName
from kag.builder.model.spg_record import SPGRecord
from kag.builder.prompt.spg_prompt import SPGPrompt
import uuid
logger = logging.getLogger(__name__)
class OneKEPrompt(SPGPrompt, ABC):
template_zh: str = ""
template_en: str = ""
def __init__(self, **kwargs):
types_list = kwargs.get("types_list", [])
language = kwargs.get("language", "zh")
with_description = kwargs.get("with_description", False)
split_num = kwargs.get("split_num", 4)
super().__init__(types_list, **kwargs)
self.language = language
if language == "zh":
self.template = self.template_zh
else:
self.template = self.template_en
self.with_description = with_description
self.split_num = split_num
self._init_render_variables()
self._render()
self.params = kwargs
def build_prompt(self, variables: Dict[str, str]) -> List[str]:
instructions = []
for schema in self.schema_list:
instructions.append(
json.dumps(
{
"instruction": self.template,
"schema": schema,
"input": variables.get("input"),
},
ensure_ascii=False,
)
)
return instructions
def parse_response(self, response: str) -> List[SPGRecord]:
raise NotImplementedError
def _render(self):
raise NotImplementedError
def multischema_split_by_num(self, split_num, schemas: List[Any]):
negative_length = max(len(schemas) // split_num, 1) * split_num
total_schemas = []
for i in range(0, negative_length, split_num):
total_schemas.append(schemas[i : i + split_num])
remain_len = max(1, split_num // 2)
tmp_schemas = schemas[negative_length:]
if len(schemas) - negative_length >= remain_len and len(tmp_schemas) > 0:
total_schemas.append(tmp_schemas)
elif len(tmp_schemas) > 0:
total_schemas[-1].extend(tmp_schemas)
return total_schemas
class OneKE_NERPrompt(OneKEPrompt):
template_zh: str = (
"你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体不存在的实体类型返回空列表。请按照JSON字符串的格式回答。"
)
template_en: str = "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string."
def __init__(
self,
entity_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=entity_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
ent_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_NERPrompt response JSONDecodeError error.")
return []
if type(ent_obj) != dict:
logger.error("OneKE_NERPrompt response type error.")
return []
spg_records = []
for type_zh, values in ent_obj.items():
if type_zh not in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized entity_type: {type_zh}")
continue
type_en, _ = self.spg_type_schema_info_zh[type_zh]
for value in values:
spg_record = SPGRecord(type_en)
spg_record.upsert_properties({"id": value, "name": value})
spg_records.append(spg_record)
return spg_records
def _render(self):
entity_list = []
for spg_type in self.spg_types:
entity_list.append(spg_type.name_zh)
self.schema_list = self.multischema_split_by_num(self.split_num, entity_list)
class OneKE_SPOPrompt(OneKEPrompt):
template_zh: str = (
"你是专门进行SPO三元组抽取的专家。请从input中抽取出符合schema定义的spo关系三元组不存在的关系返回空列表。请按照JSON字符串的格式回答。"
)
template_en: str = "You are an expert in spo(subject, predicate, object) triples extraction. Please extract SPO relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string."
def __init__(
self,
spo_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=spo_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
self.properties_mapper = {}
self.relations_mapper = {}
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
re_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_REPrompt response JSONDecodeError error.")
return []
if type(re_obj) != dict:
logger.error("OneKE_REPrompt response type error.")
return []
relation_dcir = defaultdict(list)
for relation_zh, values in re_obj.items():
if relation_zh not in self.property_info_zh[relation_zh]:
logger.warning(f"Unrecognized relation: {relation_zh}")
continue
if values and isinstance(values, list):
for value in values:
if (
type(value) != dict
or "subject" not in value
or "object" not in value
):
logger.warning("OneKE_REPrompt response type error.")
continue
s_zh, o_zh = value.get("subject", ""), value.get("object", "")
relation_dcir[relation_zh].append((s_zh, o_zh))
spg_records = []
for relation_zh, sub_obj_list in relation_dcir.items():
sub_dict = defaultdict(list)
for s_zh, o_zh in sub_obj_list:
sub_dict[s_zh].append(o_zh)
for s_zh, o_list in sub_dict.items():
if s_zh in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized subject_type: {s_zh}")
continue
object_value = ",".join(o_list)
s_type_zh = self.properties_mapper.get(relation_zh, None)
if s_type_zh is not None:
s_type_en, _ = self.spg_type_schema_info_zh[s_type_zh]
relation_en, _ = self.property_info_zh[relation_zh]
spg_record = SPGRecord(s_type_en).upsert_properties(
{"id": s_zh, "name": s_zh}
)
spg_record.upsert_property(relation_en, object_value)
else:
s_type_zh, o_type_zh = self.relations_mapper.get(
relation_zh, [None, None]
)
if s_type_zh is None or o_type_zh is None:
logger.warning(f"Unrecognized relation: {relation_zh}")
continue
s_type_en, _ = self.spg_type_schema_info_zh[s_type_zh]
spg_record = SPGRecord(s_type_en).upsert_properties(
{"id": s_zh, "name": s_zh}
)
relation_en, _, object_type = self.relation_info_zh[s_type_zh][
relation_zh
]
spg_record.upsert_relation(relation_en, object_type, object_value)
spg_records.append(spg_record)
return spg_records
def _render(self):
spo_list = []
for spg_type in self.spg_types:
type_en, _ = self.spg_type_schema_info_zh[spg_type]
for v in spg_type.properties.values():
spo_list.append(
{
"subject_type": spg_type.name_zh,
"predicate": v.name_zh,
"object_type": "文本",
}
)
self.properties_mapper[v.name_zh] = spg_type
for v in spg_type.relations.values():
_, _, object_type = self.relation_info_en[type_en][v.name]
spo_list.append(
{
"subject_type": spg_type.name_zh,
"predicate": v.name_zh,
"object_type": object_type,
}
)
self.relations_mapper[v.name_zh] = [spg_type, object_type]
self.schema_list = self.multischema_split_by_num(self.split_num, spo_list)
class OneKE_REPrompt(OneKE_SPOPrompt):
template_zh: str = (
"你是专门进行关系抽取的专家。请从input中抽取出符合schema定义的关系三元组不存在的关系返回空列表。请按照JSON字符串的格式回答。"
)
template_en: str = "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string."
def __init__(
self,
relation_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
relation_types, language, with_description, split_num, **kwargs
)
def _render(self):
re_list = []
for spg_type in self.spg_types:
type_en, _ = self.spg_type_schema_info_zh[spg_type]
for v in spg_type.properties.values():
re_list.append(v.name_zh)
self.properties_mapper[v.name_zh] = spg_type
for v in spg_type.relations.values():
v_zh, _, object_type = self.relation_info_en[type_en][v.name]
re_list.append(v.name_zh)
self.relations_mapper[v.name_zh] = [spg_type, object_type]
self.schema_list = self.multischema_split_by_num(self.split_num, re_list)
class OneKE_KGPrompt(OneKEPrompt):
template_zh: str = "你是一个图谱实体知识结构化专家。根据输入实体类型(entity type)的schema描述从文本中抽取出相应的实体实例和其属性信息不存在的属性不输出, 属性存在多值就返回列表并输出为可解析的json格式。"
template_en: str = "You are an expert in structured knowledge systems for graph entities. Based on the schema description of the input entity type, you extract the corresponding entity instances and their attribute information from the text. Attributes that do not exist should not be output. If an attribute has multiple values, a list should be returned. The results should be output in a parsable JSON format."
def __init__(
self,
entity_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=entity_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
re_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_KGPrompt response JSONDecodeError error.")
return []
if type(re_obj) != dict:
logger.error("OneKE_KGPrompt response type error.")
return []
spg_records = []
for type_zh, type_value in re_obj.items():
if type_zh not in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized entity_type: {type_zh}")
continue
type_en, _ = self.spg_type_schema_info_zh[type_zh]
if type_value and isinstance(type_value, dict):
for name, attrs in type_value.items():
spg_record = SPGRecord(type_en).upsert_properties(
{"id": name, "name": name}
)
for attr_zh, attr_value in attrs.items():
if isinstance(attr_value, list):
attr_value = ",".join(attr_value)
if attr_zh in self.property_info_zh[type_zh]:
attr_en, _, object_type = self.property_info_zh[type_zh][
attr_zh
]
spg_record.upsert_property(attr_en, attr_value)
elif attr_zh in self.relation_info_zh[type_zh]:
attr_en, _, object_type = self.relation_info_zh[type_zh][
attr_zh
]
spg_record.upsert_relation(attr_en, object_type, attr_value)
else:
logger.warning(f"Unrecognized attribute: {attr_zh}")
continue
if object_type == "Integer":
matches = re.findall(r"\d+", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
elif object_type == "Float":
matches = re.findall(r"\d+(?:\.\d+)?", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
spg_records.append(spg_record)
return spg_records
def _render(self):
spo_list = []
for spg_type in self.spg_types:
if not self.with_description:
attributes = []
attributes.extend(
[
v.name_zh
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
]
)
attributes.extend(
[
v.name_zh
for k, v in spg_type.relations.items()
if v.name_zh not in attributes
and k not in self.ignored_relations
]
)
else:
attributes = {}
attributes.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
)
attributes.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.relations.items()
if v.name_zh not in attributes
and k not in self.ignored_relations
}
)
entity_type = spg_type.name_zh
spo_list.append({"entity_type": entity_type, "attributes": attributes})
self.schema_list = self.multischema_split_by_num(self.split_num, spo_list)
class OneKE_EEPrompt(OneKEPrompt):
template_zh: str = "你是专门进行事件提取的专家。请从input中抽取出符合schema定义的事件不存在的事件返回空列表不存在的论元返回NAN如果论元存在多值请返回列表。请按照JSON字符串的格式回答。"
template_en: str = "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string."
def __init__(
self,
event_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=event_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
ee_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_EEPrompt response JSONDecodeError error.")
return []
if type(ee_obj) != dict:
logger.error("OneKE_EEPrompt response type error.")
return []
spg_records = []
for type_zh, type_values in ee_obj.items():
if type_zh not in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized event_type: {type_zh}")
continue
type_en, _ = self.spg_type_schema_info_zh[type_zh]
if type_values and isinstance(type_values, list):
for type_value in type_values:
uuid_4 = uuid.uuid4()
spg_record = (
SPGRecord(type_en)
.upsert_property("id", str(uuid_4))
.upsert_property("name", type_zh)
)
arguments = type_value.get("arguments")
if arguments and isinstance(arguments, dict):
for attr_zh, attr_value in arguments.items():
if isinstance(attr_value, list):
attr_value = ",".join(attr_value)
if attr_zh in self.property_info_zh[type_zh]:
attr_en, _, object_type = self.property_info_zh[
type_zh
][attr_zh]
spg_record.upsert_property(attr_en, attr_value)
elif attr_zh in self.relation_info_zh[type_zh]:
attr_en, _, object_type = self.relation_info_zh[
type_zh
][attr_zh]
spg_record.upsert_relation(
attr_en, object_type, attr_value
)
else:
logger.warning(f"Unrecognized attribute: {attr_zh}")
continue
if object_type == "Integer":
matches = re.findall(r"\d+", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
elif object_type == "Float":
matches = re.findall(r"\d+(?:\.\d+)?", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
spg_records.append(spg_record)
return spg_records
def _render(self):
event_list = []
for spg_type in self.spg_types:
if not self.with_description:
arguments = []
arguments.extend(
[
v.name_zh
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
]
)
arguments.extend(
[
v.name_zh
for k, v in spg_type.relations.items()
if v.name_zh not in arguments
and k not in self.ignored_relations
]
)
else:
arguments = {}
arguments.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
)
arguments.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.relations.items()
if v.name_zh not in arguments
and k not in self.ignored_relations
}
)
event_type = spg_type.name_zh
event_list.append(
{"event_type": event_type, "trigger": True, "arguments": arguments}
)
self.schema_list = self.multischema_split_by_num(self.split_num, event_list)

View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Optional, List
from kag.interface import PromptABC
import ast
@PromptABC.register("outline_align")
class OutlineAlignPrompt(PromptABC):
template_zh = """
{
"instruction": "请分析以下大纲列表,统一调整标题的层级。遵循以下规则:
1. 相同类型的标题应该有相同的层级例如所有'第X章'都应该是同一层级
2. 层级关系应该符合逻辑例如:
- (1) > (2) > (3)
- 部分(1) > (2) > (3)
3. 考虑标题的上下文关系确保层级的连贯性
4. 如果标题不含明确的层级标识根据其内容和上下文推断合适的层级
请务必按照以下格式返回不要返回其他任何内容请返回调整后的大纲列表格式为:
[(标题1, 层级1), (标题2, 层级2), ...]
输入的大纲列表为:
$outlines",
"example": [
{
"input": [
("第一章 绪论", 2),
("第一节 研究背景", 1),
("第二章 文献综述", 1),
("第二节 研究方法", 2)
],
"output": [
("第一章 绪论", 1),
("第一节 研究背景", 2),
("第二章 文献综述", 1),
("第二节 研究方法", 2)
]
}
]
}
"""
template_en = """
{
"instruction": "Please analyze the following outline list and unify the levels of titles according to these rules:
1. Similar types of titles should have the same level (e.g., all 'Chapter X' should be at the same level)
2. Level relationships should follow logic, e.g.:
- Chapter(1) > Section(2) > Article(3)
- Part(1) > Chapter(2) > Section(3)
3. Consider context relationships between titles to ensure level continuity
4. For titles without clear level indicators, infer appropriate levels based on content and context
Please return the adjusted outline list in the format:
[(title1, level1), (title2, level2), ...]
Input outline list:
$outlines",
"example": [
{
"input": [
("Chapter 1 Introduction", 2),
("Section 1.1 Background", 1),
("Chapter 2 Literature Review", 1),
("Section 2.1 Methods", 2)
],
"output": [
("Chapter 1 Introduction", 1),
("Section 1.1 Background", 2),
("Chapter 2 Literature Review", 1),
("Section 2.1 Methods", 2)
]
}
]
}
"""
def __init__(self, language: Optional[str] = "zh"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["outlines"]
def parse_response(self, response: str, **kwargs):
if isinstance(response, str):
cleaned_data = response.strip("`python\n[] \n")
cleaned_data = "[" + cleaned_data + "]"
return ast.literal_eval(cleaned_data)
if isinstance(response, dict) and "output" in response:
return response["output"]
return response

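A standalone sketch of what the parse_response above tolerates: the model reply may come back wrapped in a Markdown code fence, which strip removes from both ends before the tuple list is re-wrapped in brackets and literal-evaluated. The reply below is a hypothetical example, not real model output.

import ast

# Hypothetical model reply wrapped in a Markdown fence.
response = "```python\n[(\"第一章 绪论\", 2), (\"第一节 研究背景\", 1)]\n```"

# Same cleanup as parse_response: strip fence/bracket characters from both ends,
# then re-wrap in brackets before evaluating the tuple list.
cleaned = response.strip("`python\n[] \n")
cleaned = "[" + cleaned + "]"
print(ast.literal_eval(cleaned))
# [('第一章 绪论', 2), ('第一节 研究背景', 1)]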
View File

@ -10,74 +10,43 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
import ast
class OutlinePrompt(PromptOp):
@PromptABC.register("outline")
class OutlinePrompt(PromptABC):
template_zh = """
{
"instruction": "\n请理解input字段中的文本内容识别文本的结构和组成部分并帮我提取出以下内容的标题可能有多个标题分散在文本的各个地方仅返属于原文的回标题文本即可不要返回其他任何内容须按照python list的格式回答具体形式请遵从example字段中给出的若干例子。",
"instruction": "\n给定一段纯文本内容,请提取其中的标题,并返回一个列表。每个标题应包含以下信息:\n- 标题文本\n- 标题级别(例如 1 表示一级标题2 表示二级标题等)\n\n假设标题遵循以下规则:\n1. 标题通常带有数字我们的文本可能是从一些图片OCR生成的所以标题可能隐藏在段落中尽可能找出这些隐藏在段落中带有数字的标题\n2. 标题的级别可以通过以下方式推断:\n - 一级标题:通常是篇章级别的内容。\n - 二级标题:通常是章节级别的内容,具有简洁的文字描述,有时以 \"第X部分\"\"第X章\"\"Part X\" 等类似形式开头。\n - 三级标题及以下:通常是段落或细节级别的标题,可能包含数字编号(如\"1.\"\"1.1\"),或者较长且具体的描述(如\"1.1 子标题\"\"第1节 概述\")。\n3. 标题的级别也可以通过上下文判断:\n - 如果两个标题之间的文本内容非常短(例如少于一定字数),后面的标题可能是更高或相同级别的标题。\n - 连续编号的标题如“第1条”“第2条”通常属于同一级别。\n - 标题层级通常由其数字层次决定例如“1”“1.1”“1.1.1”依次为 1 级、2 级、3 级。\n - 如果一个标题包含关键词如“部分”“章”“节”“条”,且其长度适中(例如 5 至 20 个字符),该标题的级别往往比更长或更短的标题要高。\n4. 以下标题可以直接忽略:\n - 含有纯数字或仅由数字和标点组成的标题例如“1.”、“2.1”等)。\n - 重复出现的标题(例如页眉或页脚被误识别为标题的情况)。\n5. 如果某些内容无法明确判断为标题,或者不符合上述规则,请忽略。\n\n请根据上述规则,返回一个包含标题和对应级别的列表,格式如下:\n[\n (\"标题文本1\", 1),\n (\"标题文本2\", 2),\n (\"标题文本3\", 3),\n ...\n]我还会给你提供之前内容抽取出的目录current_outlines你需要根据当前已经抽取的目录自行判断抽取标题的粒度以及对应的等级",
"input": "$input",
"current_outline:": "$current_outline",
"example": [
{
"input": "第8条 原 则
1.各成员方在制订或修正其法律和规章时可采取必要措施以保护公众健康和营养并促进对其社会经济和技术发展至关重要部门的公众利益只要该措施符合本协议规定
2.可能需要采取与本协议的规定相一致的适当的措施以防止知识产权所有者滥用知识产权或藉以对贸易进行不合理限制或实行对国际间的技术转让产生不利影响的作法
第二部分 关于知识产权的效力范围及使用的标准
第1节 版权及相关权利
第9条 伯尔尼公约的关系",
"input": "第8条 原 则\n\n1.各成员方在制订或修正其法律和规章时,可采取必要措施以保护公众健康和营养,并促进对其社会经济和技术发展至关重要部门的公众利益,只要该措施符合本协议规定。\n\n2.可能需要采取与本协议的规定相一致的适当的措施,以防止知识产权所有者滥用知识产权或藉以对贸易进行不合理限制或实行对国际间的技术转让产生不利影响的作法。\n\n第二部分 关于知识产权的效力、范围及使用的标准\n\n第1节 版权及相关权利\n\n第9条 与《伯尔尼公约》的关系",
"output": [
"第8条 原 则",
"第二部分 关于知识产权的效力、范围及使用的标准",
"第1节 版权及相关权利",
"第9条 与《伯尔尼公约》的关系"
],
("第8条 原 则",3),
("第二部分 关于知识产权的效力、范围及使用的标准",1),
("第1节 版权及相关权利",2),
("第9条 与《伯尔尼公约》的关系",3)
]
},
{
"input": "第16条 授予权利
1.已注册商标所有者应拥有阻止所有未经其同意的第三方在贸易中使用与已注册商标相同或相似的商品或服务的其使用有可能招致混淆的相同或相似的标志在对相同商品或服务使用相同标志的情况下应推定存在混淆之可能上述权利不应妨碍任何现行的优先权也不应影响各成员方以使用为条件获得注册权的可能性
2.1967巴黎公约第6条副则经对细节作必要修改后应适用于服务在确定一个商标是否为知名商标时各成员方应考虑到有关部分的公众对该商标的了解包括由于该商标的推行而在有关成员方得到的了解
3.1967巴黎公约第6条副则经对细节作必要修改后应适用于与已注册商标的商品和服务不相似的商品或服务条件是该商标与该商品和服务有关的使用会表明该商品或服务与已注册商标所有者之间的联系而且已注册商标所有者的利益有可能为此种使用所破坏
第17条  \n ",
"input": "第16条 授予权利\n\n1.已注册商标所有者应拥有阻止所有未经其同意的第三方在贸易中使用与已注册商标相同或相似的商品或服务的,其使用有可能招致混淆的相同或相似的标志。在对相同商品或服务使用相同标志的情况下,应推定存在混淆之可能。上述权利不应妨碍任何现行的优先权,也不应影响各成员方以使用为条件获得注册权的可能性。\n\n2.1967《巴黎公约》第6条副则经对细节作必要修改后应适用于服务。在确定一个商标是否为知名商标时各成员方应考虑到有关部分的公众对该商标的了解包括由于该商标的推行而在有关成员方得到的了解。\n\n3.1967《巴黎公约》第6条副则经对细节作必要修改后应适用于与已注册商标的商品和服务不相似的商品或服务条件是该商标与该商品和服务有关的使用会表明该商品或服务与已注册商标所有者之间的联系而且已注册商标所有者的利益有可能为此种使用所破坏。\n\n第17条 例 外\n ",
"output": [
"第16条 授予权利",
"第17条 例 外"
],
("第16条 授予权利",3),
("第17条 例 外",3)
]
},
{
"input":"的做法。
4此类使用应是非独占性的
5此类使用应是不可转让的除非是同享有此类使用的那部分企业或信誉一道转让
6任何此类使用之授权均应主要是为授权此类使用的成员方国内市场供应之目的
7在被授权人的合法利益受到充分保护的条件下当导致此类使用授权的情况下不复存在和可能不再产生时有义务将其终止应有动机的请求主管当局应有权对上述情况的继续存在进行检查
8考虑到授权的经济价值应视具体情况向权利人支付充分的补偿金
9任何与此类使用之授权有关的决定其法律效力应接受该成员方境内更高当局的司法审查或其他独立审查
10任何与为此类使用而提供的补偿金有关的决定应接受成员方境内更高当局的司法审查或其他独立审查
",
"output": [],
},
"input": "的做法。\n\n4此类使用应是非独占性的。\n\n5此类使用应是不可转让的除非是同享有此类使用的那部分企业或信誉一道转让。\n\n6任何此类使用之授权均应主要是为授权此类使用的成员方国内市场供应之目的。\n\n7在被授权人的合法利益受到充分保护的条件下当导致此类使用授权的情况下不复存在和可能不再产生时有义务将其终止应有动机的请求主管当局应有权对上述情况的继续存在进行检查。\n\n8考虑到授权的经济价值应视具体情况向权利人支付充分的补偿金。\n\n9任何与此类使用之授权有关的决定其法律效力应接受该成员方境内更高当局的司法审查或其他独立审查。\n\n10任何与为此类使用而提供的补偿金有关的决定应接受成员方境内更高当局的司法审查或其他独立审查。\n",
"output": []
}
]
}
"""
}
"""
template_en = """
{
@ -147,11 +116,16 @@ Article 17 Exceptions
@property
def template_variables(self) -> List[str]:
return ["input"]
return ["input", "current_outline"]
def parse_response(self, response: str, **kwargs):
if isinstance(response, str):
response = json.loads(response)
cleaned_data = response.strip("`python\n[] \n") # strip the Markdown fence and surrounding whitespace
cleaned_data = "[" + cleaned_data + "]" # restore the list brackets
# use ast.literal_eval to turn the string into an actual list object
list_data = ast.literal_eval(cleaned_data)
return list_data
if isinstance(response, dict) and "output" in response:
response = response["output"]

View File

@ -11,12 +11,13 @@
# or implied.
import json
from typing import Optional, List
from typing import List
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
class SemanticSegPrompt(PromptOp):
@PromptABC.register("semantic_seg")
class SemanticSegPrompt(PromptABC):
template_zh = """
{
"instruction": "\n请理解input字段中的文本内容识别文本的结构和组成部分并按照语义主题确定分割点将其切分成互不重叠的若干小节。如果文章有章节等可识别的结构信息请直接按照顶层结构进行切分。\n请按照schema定义的字段返回包含小节摘要和小节起始点。须按照JSON字符串的格式回答。具体形式请遵从example字段中给出的若干例子。",
@ -111,9 +112,6 @@ class SemanticSegPrompt(PromptOp):
}
"""
def __init__(self, language: Optional[str] = "zh"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input"]

View File

@ -12,244 +12,589 @@
import json
import logging
from abc import ABC
import copy
from typing import List, Dict
from kag.common.base.prompt_op import PromptOp
from kag.interface import PromptABC
from knext.schema.client import SchemaClient
from knext.schema.model.base import BaseSpgType, SpgTypeEnum
from knext.schema.model.base import SpgTypeEnum, ConstraintTypeEnum
from knext.schema.model.schema_helper import SPGTypeName
from kag.builder.model.spg_record import SPGRecord
from kag.common.conf import KAG_PROJECT_CONF
from knext.schema.client import OTHER_TYPE
logger = logging.getLogger(__name__)
class SPGPrompt(PromptOp, ABC):
spg_types: Dict[str, BaseSpgType]
class SPGPrompt(PromptABC):
"""
Base class for generating SPG schema-based entity/event extraction prompts.
Attributes:
ignored_types (List[str]): List of SPG types to be ignored.
ignored_properties (List[str]): List of properties to be ignored.
default_properties (Dict[str, str]): Default properties for SPG types.
ignored_relations (List[str]): List of relations to be ignored.
"""
ignored_types: List[str] = ["Chunk"]
ignored_properties: List[str] = ["id", "name", "description", "stdId", "eventTime", "desc", "semanticType"]
ignored_properties: List[str] = [
"id",
"stdId",
"desc",
"description",
"eventTime",
]
default_properties: Dict[str, str] = {
"name": "Text",
}
ignored_relations: List[str] = ["isA"]
basic_types = {"Text": "文本", "Integer": "整型", "Float": "浮点型"}
def __init__(
self,
spg_type_names: List[SPGTypeName],
language: str = "zh",
spg_type_names: List[SPGTypeName] = [],
language: str = "",
**kwargs,
):
"""
Initializes the SPGPrompt instance.
Args:
spg_type_names (List[SPGTypeName], optional): List of SPG type names. Defaults to [].
language (str, optional): Language for the prompt. Defaults to "".
**kwargs: Additional keyword arguments.
"""
super().__init__(language=language, **kwargs)
self.all_schema_types = SchemaClient(project_id=self.project_id).load()
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
self.spg_type_names = spg_type_names
if not spg_type_names:
self.spg_types = self.all_schema_types
self.spg_types = self.schema
else:
self.spg_types = {k: v for k, v in self.all_schema_types.items() if k in spg_type_names}
self.schema_list = []
self._init_render_variables()
self.spg_types = {
k: v for k, v in self.schema.items() if k in spg_type_names
}
self.create_prompt_schema()
# self._init_render_variables()
@property
def template_variables(self) -> List[str]:
"""
Returns the list of template variables used in the prompt.
Returns:
List[str]: List of template variables.
"""
return ["schema", "input"]
def _init_render_variables(self):
self.type_en_to_zh = {"Text": "文本", "Integer": "整型", "Float": "浮点型"}
self.type_zh_to_en = {
"文本": "Text",
"整型": "Integer",
"浮点型": "Float",
}
self.prop_en_to_zh = {}
self.prop_zh_to_en = {}
for type_name, spg_type in self.all_schema_types.items():
self.type_en_to_zh[type_name] = spg_type.name_zh
self.type_en_to_zh[spg_type.name_zh] = type_name
self.prop_zh_to_en[type_name] = {}
self.prop_en_to_zh[type_name] = {}
for _prop in spg_type.properties.values():
if _prop.name in self.ignored_properties:
continue
self.prop_en_to_zh[type_name][_prop.name] = _prop.name_zh
self.prop_zh_to_en[type_name][_prop.name_zh] = _prop.name
for _rel in spg_type.relations.values():
if _rel.is_dynamic:
continue
self.prop_en_to_zh[type_name][_rel.name] = _rel.name_zh
self.prop_zh_to_en[type_name][_rel.name_zh] = _rel.name
def get_accept_types(self):
"""
Returns the list of accepted SPG types.
def _render(self):
raise NotImplementedError
class SPG_KGPrompt(SPGPrompt):
template_zh: str = """
{
"instruction": "你是一个图谱知识抽取的专家, 基于constraint 定义的schema从input 中抽取出所有的实体及其属性input中未明确提及的属性返回NAN以标准json 格式输出结果返回list",
"schema": $schema,
"example": [
{
"input": "甲状腺结节是指在甲状腺内的肿块可随吞咽动作随甲状腺而上下移动是临床常见的病症可由多种病因引起。临床上有多种甲状腺疾病如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发也可以多发多发结节比单发结节的发病率高但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科甲状腺外科内分泌科头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下甲状腺结节没有任何症状甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。",
"schema": {
"Disease": {
"properties": {
"complication": "并发症",
"commonSymptom": "常见症状",
"applicableMedicine": "适用药品",
"department": "就诊科室",
"diseaseSite": "发病部位",
}
},"Medicine": {
"properties": {
}
}
}
"output": [
{
"entity": "甲状腺结节",
"category":"Disease"
"properties": {
"complication": "甲状腺癌",
"commonSymptom": ["颈部疼痛", "咽喉部异物感", "压迫感"],
"applicableMedicine": ["复方碘口服液(Lugol液)", "丙基硫氧嘧啶(PTU)", "甲基硫氧嘧啶(MTU)", "甲硫咪唑", "卡比马唑"],
"department": ["普外科", "甲状腺外科", "内分泌科", "头颈外科"],
"diseaseSite": "甲状腺",
}
},{
"entity":"复方碘口服液(Lugol液)",
"category":"Medicine"
},{
"entity":"丙基硫氧嘧啶(PTU)",
"category":"Medicine"
},{
"entity":"甲基硫氧嘧啶(MTU)",
"category":"Medicine"
},{
"entity":"甲硫咪唑",
"category":"Medicine"
},{
"entity":"卡比马唑",
"category":"Medicine"
}
],
"input": "$input"
}
"""
template_en: str = """
{
"instruction": "You are an expert in knowledge graph extraction. Based on the schema defined by constraints, extract all entities and their attributes from the input. For attributes not explicitly mentioned in the input, return NAN. Output the results in standard JSON format as a list.",
"schema": $schema,
"example": [
{
"input": "Thyroid nodules refer to lumps within the thyroid gland that can move up and down with swallowing, and they are a common clinical condition that can be caused by various etiologies. Clinically, many thyroid diseases, such as thyroid degeneration, inflammation, autoimmune conditions, and neoplasms, can present as nodules. Thyroid nodules can occur singly or in multiple forms; multiple nodules have a higher incidence than single nodules, but single nodules have a higher likelihood of being thyroid cancer. Patients typically have the option to register for consultation in general surgery, thyroid surgery, endocrinology, or head and neck surgery. Some patients can feel the nodules in the front of their neck. In most cases, thyroid nodules are asymptomatic, and thyroid function is normal. The probability of thyroid nodules progressing to other thyroid diseases is only about 1%. Some individuals may experience neck pain, a foreign body sensation in the throat, or a feeling of pressure. When spontaneous intracystic bleeding occurs in a thyroid nodule, the pain can be more intense. Treatment options generally include radioactive iodine therapy, Lugol's solution (a compound iodine oral solution), or antithyroid medications to suppress thyroid hormone secretion. Currently, commonly used antithyroid drugs are thiourea compounds, including propylthiouracil (PTU) and methylthiouracil (MTU) from the thiouracil class, and methimazole and carbimazole from the imidazole class.",
"schema": {
"Disease": {
"properties": {
"complication": "Disease",
"commonSymptom": "Symptom",
"applicableMedicine": "Medicine",
"department": "HospitalDepartment",
"diseaseSite": "HumanBodyPart"
}
},"Medicine": {
"properties": {
}
}
}
"output": [
{
"entity": "Thyroid Nodule",
"category": "Disease",
"properties": {
"complication": "Thyroid Cancer",
"commonSymptom": ["Neck Pain", "Foreign Body Sensation in the Throat", "Feeling of Pressure"],
"applicableMedicine": ["Lugol's Solution (Compound Iodine Oral Solution)", "Propylthiouracil (PTU)", "Methylthiouracil (MTU)", "Methimazole", "Carbimazole"],\n "department": ["General Surgery", "Thyroid Surgery", "Endocrinology", "Head and Neck Surgery"],\n "diseaseSite": "Thyroid"\n }\n },\n {\n "entity": "Lugol's Solution (Compound Iodine Oral Solution)",
"category": "Medicine"
},
{
"entity": "Propylthiouracil (PTU)",
"category": "Medicine"
},
{
"entity": "Methylthiouracil (MTU)",
"category": "Medicine"
},
{
"entity": "Methimazole",
"category": "Medicine"
},
{
"entity": "Carbimazole",
"category": "Medicine"
}
],
"input": "$input"
}
"""
def __init__(
self,
spg_type_names: List[SPGTypeName],
language: str = "zh",
**kwargs
):
super().__init__(
spg_type_names=spg_type_names,
language=language,
**kwargs
)
self._render()
Returns:
List[SpgTypeEnum]: List of accepted SPG types.
"""
return [
SpgTypeEnum.Entity,
SpgTypeEnum.Concept,
SpgTypeEnum.Event,
]
def build_prompt(self, variables: Dict[str, str]) -> str:
schema = {}
for tmpSchema in self.schema_list:
schema.update(tmpSchema)
"""
Builds the prompt using the provided variables.
return super().build_prompt({"schema": schema, "input": variables.get("input")})
Args:
variables (Dict[str, str]): Dictionary of variables to be used in the prompt.
Returns:
str: The built prompt.
"""
return super().build_prompt(
{
"schema": copy.deepcopy(self.prompt_schema),
"input": variables.get("input"),
}
)
def process_property_name(self, name: str):
"""
Process property name by removing descriptions enclosed in parentheses.
Args:
name (str): a property name, possibly containing a description in parentheses
Returns:
str: A new string having the descriptions in parentheses removed.
Example:
>>> name = 'authors(authors of work, such as director, actor, lyricist, composer and singer)'
>>> process_property_name(name)
'authors'
"""
return name.split("(")[0]
def process_property_names(self, properties: Dict):
"""
Process property names by removing descriptions enclosed in parentheses.
This method iterates through the given dictionary of properties, removes any
descriptions enclosed in parentheses from the property names, and returns a new
dictionary with the processed names. If a property value is itself a dictionary,
this method will recursively process it.
Args:
properties (dict): A dictionary where keys are property names (possibly containing
descriptions in parentheses) and values are either property values
or nested dictionaries.
Returns:
dict: A new dictionary with the same structure as the input, but with all property
names having their descriptions in parentheses removed.
Example:
>>> input_properties = {
... "authors(authors of work, such as director, actor, lyricist, composer and singer)": "John Doe"
... }
>>> process_property_names(input_properties)
{'authors': 'John Doe'}
"""
output = {}
for k, v in properties.items():
k = self.process_property_name(k)
if isinstance(v, dict):
output[k] = self.process_property_names(v)
else:
output[k] = v
return output
def parse_response(self, response: str, **kwargs) -> List[SPGRecord]:
"""
Parses the response string into a list of SPG records.
Args:
response (str): The response string to be parsed.
**kwargs: Additional keyword arguments.
Returns:
List[SPGRecord]: List of parsed SPG records.
"""
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
entities = rsp["named_entities"]
else:
entities = rsp
return entities
def _render(self):
spo_list = []
for type_name, spg_type in self.spg_types.items():
if spg_type.spg_type_enum not in [SpgTypeEnum.Entity, SpgTypeEnum.Concept, SpgTypeEnum.Event]:
outputs = []
for item in rsp:
if "category" not in item or item["category"] not in self.schema:
continue
constraint = {}
properties = {}
properties.update(
{
v.name: (f"{v.name_zh}" if not v.desc else f"{v.name_zh}{v.desc}") if self.language == "zh" else (f"{v.name}" if not v.desc else f"{v.name}, {v.desc}")
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
)
properties.update(
{
f"{v.name}#{v.object_type_name_en}": (
f"{v.name_zh},类型是{v.object_type_name_zh}"
if not v.desc
else f"{v.name_zh}{v.desc},类型是{v.object_type_name_zh}"
) if self.language == "zh" else (
f"{v.name}, the type is {v.object_type_name_en}"
if not v.desc
else f"{v.name}{v.desc}, the type is {v.object_type_name_en}"
)
for k, v in spg_type.relations.items()
if not v.is_dynamic and k not in self.ignored_relations
}
)
constraint.update({"properties": properties})
spo_list.append({type_name: constraint})
properties = item.get("properties", {})
if "name" not in properties:
continue
output = {}
output["category"] = item["category"]
output["name"] = properties.pop("name")
output["properties"] = self.process_property_names(properties)
outputs.append(output)
return outputs
self.schema_list = spo_list
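To make the rewritten parse_response contract concrete, a data-only walk-through (not the method itself): records with an unknown category or without a name property are dropped, and parenthesized descriptions are stripped from the remaining property keys.

# Assume the project schema only defines the "Disease" type.
llm_output = [
    {"category": "Disease",
     "properties": {"name": "甲状腺结节",
                    "complication(并发症)": "甲状腺癌",
                    "department(就诊科室)": ["普外科", "内分泌科"]}},
    {"category": "UnknownType", "properties": {"name": "foo"}},            # dropped: category not in schema
    {"category": "Disease", "properties": {"complication(并发症)": "x"}},   # dropped: no "name" property
]
# Only the first record survives; after process_property_names it becomes:
# {"category": "Disease",
#  "name": "甲状腺结节",
#  "properties": {"complication": "甲状腺癌", "department": ["普外科", "内分泌科"]}}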
def create_prompt_schema(self):
"""
Creates the schema for extraction prompt based on the project schema.
"""
prompt_schema = []
accept_types = self.get_accept_types()
for type_name, spg_type in self.spg_types.items():
if type_name in self.ignored_types:
continue
if spg_type.spg_type_enum not in accept_types:
continue
type_desc = spg_type.desc
properties = copy.deepcopy(self.default_properties)
for k, v in spg_type.properties.items():
if k in self.ignored_properties or k in self.default_properties:
continue
multi_value = ConstraintTypeEnum.MultiValue.value in v.constraint
obj_type_name = v.object_type_name.split(".")[-1]
if multi_value:
obj_type_name = f"List[{obj_type_name}]"
if v.desc:
v_name = f"{v.name}({v.desc})"
else:
v_name = v.name
properties[v_name] = obj_type_name
for k, v in spg_type.relations.items():
if k in self.ignored_relations or k in self.default_properties:
continue
if v.name in properties:
continue
obj_type_name = v.object_type_name.split(".")[-1]
if v.desc:
v_name = f"{v.name}({v.desc})"
else:
v_name = v.name
properties[v_name] = obj_type_name
if type_desc:
prompt_schema.append(
{f"{type_name}({type_desc})": {"properties": properties}}
)
else:
prompt_schema.append({type_name: {"properties": properties}})
self.prompt_schema = prompt_schema
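For reference, a hypothetical rendering of prompt_schema for a single type under the rules above: the default name: Text property is injected, multi-valued property constraints become List[...], and property or type descriptions are appended in parentheses. The type and property names below are illustrative, not taken from a real project schema.

prompt_schema = [
    {
        "Disease(疾病)": {                                # type_name(type_desc)
            "properties": {
                "name": "Text",                           # injected default property
                "complication(并发症)": "List[Disease]",   # MultiValue property constraint
                "diseaseSite": "HumanBodyPart",           # property without a desc keeps its bare name
            }
        }
    }
]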
@PromptABC.register("spg_entity")
class SPGEntityPrompt(SPGPrompt):
template_zh: dict = {
"instruction": "作为一个图谱知识抽取的专家, 你需要基于定义了实体类型及对应属性的schema从input字段的文本中抽取出所有的实体及其属性schema中标记为List的属性返回list未能提取的属性返回null。以标准json list格式输出list中每个元素形如{category: properties}你可以参考example字段中给出的示例格式。注意实体属性的SemanticType指的是一个相比实体类型更具体且明确定义的类型例如Person类型的SemanticType可以是Professor或Actor。",
"example": [
{
"input": "周杰伦Jay Chou1979年1月18日出生于台湾省新北市祖籍福建省永春县华语流行乐男歌手、音乐人、演员、导演、编剧毕业于淡江中学。2000年发行个人首张音乐专辑《Jay》 [26]。2023年凭借《最伟大的作品》获得第一届浪潮音乐大赏年度制作、最佳作曲、最佳音乐录影带三项大奖。",
"output": [
{
"category": "Person",
"properties": {
"name": "周杰伦",
"semanticType": "Musician",
"description": "华语流行乐男歌手、音乐人、演员、导演、编剧",
},
},
{
"category": "GeographicLocation",
"properties": {
"name": "台湾省新北市",
"semanticType": "City",
"description": "周杰伦的出生地",
},
},
{
"category": "GeographicLocation",
"properties": {
"name": "福建省永春县",
"semanticType": "County",
"description": "周杰伦的祖籍",
},
},
{
"category": "Organization",
"properties": {
"name": "淡江中学",
"semanticType": "School",
"description": "周杰伦的毕业学校",
},
},
{
"category": "Works",
"properties": {
"name": "Jay",
"semanticType": "Album",
"description": "周杰伦的个人首张音乐专辑",
},
},
{
"category": "Works",
"properties": {
"name": "最伟大的作品",
"semanticType": "MusicVideo",
"description": "周杰伦凭借此作品获得多项音乐大奖",
},
},
],
}
],
}
template_en: dict = {
"instruction": "As an expert in graph knowledge extraction, you need to extract all entities and their properties from the text in the input field based on a schema that defines entity types and their corresponding attributes. Attributes marked as List in the schema should return a list, and attributes not extracted should return null. Output the results in a standard JSON list format, where each element in the list is in the form of {category: properties}. You can refer to the example format provided in the example field. Note that the SemanticType of an entity attribute refers to a more specific and clearly defined type compared to the entity type itself, such as Professor or Actor for the Person type.",
"example": [
{
"input": "Jay Chou, born on January 18, 1979, in New Taipei City, Taiwan Province, with ancestral roots in Yongchun County, Fujian Province, is a renowned male singer, musician, actor, director, and screenwriter in the realm of Chinese pop music. He graduated from Tamkang University. In 2000, he released his debut solo album, <Jay> [26]. In 2023, he was honored with three major awards at the inaugural Wave Music Awards for Best Production, Best Composition, and Best Music Video for his album The Greatest Work.",
"output": [
{
"category": "Person",
"properties": {
"name": "Jay Chou",
"semanticType": "Musician",
"description": "renowned male singer, musician, actor, director, and screenwriter in the realm of Chinese pop music",
},
},
{
"category": "GeographicLocation",
"properties": {
"name": "New Taipei City, Taiwan Province",
"semanticType": "City",
"description": "Jay Chou's birthplace",
},
},
{
"category": "GeographicLocation",
"properties": {
"name": "Yongchun County, Fujian Province",
"semanticType": "County",
"description": "Jay Chou's ancestral roots",
},
},
{
"category": "Organization",
"properties": {
"name": "Tamkang University",
"semanticType": "University",
"description": "Jay Chou's alma mater",
},
},
{
"category": "Works",
"properties": {
"name": "Jay",
"semanticType": "Album",
"description": "Jay Chou's debut solo album",
},
},
{
"category": "Works",
"properties": {
"name": "The Greatest Work",
"semanticType": "Album",
"description": "Jay Chou's album for which he won multiple awards",
},
},
],
}
],
}
def get_accept_types(self):
return [
SpgTypeEnum.Entity,
SpgTypeEnum.Concept,
]
@PromptABC.register("spg_event")
class SPGEventPrompt(SPGPrompt):
template_zh: dict = {
"instruction": "作为一个知识图谱图谱事件抽取的专家, 你需要基于定义的事件类型及对应属性的schema从input字段的文本中抽取出所有的事件及其属性schema中标记为List的属性返回list未能提取的属性返回null。以标准json list格式输出list中每个元素形如{category: properties}你可以参考example字段中给出的示例格式。",
"example": {
"input": "1986年周星驰被调入无线电视台戏剧组同年他在单元情景剧《哥哥的女友》中饰演可爱活泼又略带羞涩的潘家伟这也是他第一次在情景剧中担任男主角之后他还在温兆伦、郭晋安等人主演的电视剧中跑龙套。",
"output": [
{
"category": "Event",
"properties": {
"name": "周星驰被调入无线电视台戏剧组",
"abstract": "1986年周星驰被调入无线电视台戏剧组。",
"subject": "周星驰",
"time": "1986年",
"location": "无线电视台",
"participants": [],
"semanticType": "调动",
},
},
{
"category": "Event",
"properties": {
"name": "周星驰在《哥哥的女友》中饰演潘家伟",
"abstract": "1986年周星驰在单元情景剧《哥哥的女友》中饰演可爱活泼又略带羞涩的潘家伟这也是他第一次在情景剧中担任男主角。",
"subject": "周星驰",
"time": "1986年",
"location": None,
"participants": [],
"semanticType": "演出",
},
},
{
"category": "Event",
"properties": {
"name": "周星驰跑龙套",
"abstract": "1986年周星驰在温兆伦、郭晋安等人主演的电视剧中跑龙套。",
"subject": "周星驰",
"time": "1986年",
"location": None,
"participants": ["温兆伦", "郭晋安"],
"semanticType": "演出",
},
},
],
},
}
template_en: dict = {
"instruction": "As an expert in knowledge graph event extraction, you need to extract all events and their attributes from the text in the input field based on the defined event types and corresponding attribute schema. For attributes marked as List in the schema, return them as a list, and for attributes that cannot be extracted, return null. Output in the standard JSON list format, with each element in the list having the form {category: properties}. You can refer to the example format provided in the example field.",
"example": {
"input": "In 1986, Stephen Chow was transferred to the drama department of Television Broadcasts Limited (TVB). In the same year, he played the role of Pan Jiawei, a lovable, lively, and slightly shy character, in the episodic situational comedy <My Brother's Girlfriend.> This was his first time taking on a lead role in a sitcom. Later, he also had minor roles in TV series starring actors such as Anthony Wong and Aaron Kwok.",
"output": [
{
"category": "Event",
"properties": {
"name": "Stephen Chow was transferred to the drama department of TVB",
"abstract": "In 1986, Stephen Chow was transferred to the drama department of Television Broadcasts Limited (TVB).",
"subject": "Stephen Chow",
"time": "1986",
"location": "Television Broadcasts Limited (TVB)",
"participants": [],
"semanticType": "调动",
},
},
{
"category": "Event",
"properties": {
"name": "Stephen Chow played Pan Jiawei in My Brother's Girlfriend",
"abstract": "In 1986, Stephen Chow played the role of Pan Jiawei, a lovable, lively, and slightly shy character, in the episodic situational comedy <My Brother's Girlfriend.> This was his first time taking on a lead role in a sitcom.",
"subject": "Stephen Chow",
"time": "1986",
"location": None,
"participants": [],
"semanticType": "演出",
},
},
{
"category": "Event",
"properties": {
"name": "Stephen Chow had minor roles in TV series",
"abstract": "Later, Stephen Chow also had minor roles in TV series starring actors such as Anthony Wong and Aaron Kwok.",
"subject": "Stephen Chow",
"time": None,
"location": None,
"participants": ["Anthony Wong", "Aaron Kwok"],
"semanticType": "演出",
},
},
],
},
}
def get_accept_types(self):
return [
SpgTypeEnum.Event,
]
@PromptABC.register("spg_relation")
class SPGRelationPrompt(SPGPrompt):
template_zh: dict = {
"instruction": "您是一位专门从事开放信息提取OpenIE的专家。schema定义了你需要关注的实体类型以及可选的用括号包围的类型解释entity_list是一组实体列表。请从input字段的文本中提取任何可能的[主语实体,主语实体类类型,谓语,宾语实体,宾语实体类型]五元组并按照JSON列表格式列出它们。请严格遵循以下要求\n1. 主语实体和宾语实体应至少有一个包含在entity_list实体列表但不要求都包含\n2. 主语和宾语实体类型必须是schema定义的类型否则无效\n3. 明确地将代词解析为对应名称,以保持清晰度。",
"example": {
"input": "1986年周星驰被调入无线电视台戏剧组同年他在单元情景剧《哥哥的女友》中饰演可爱活泼又略带羞涩的潘家伟这也是他第一次在情景剧中担任男主角之后他还在温兆伦、郭晋安等人主演的电视剧中跑龙套。",
"entity_list": [
{"name": "周星驰", "category": "Person"},
{"name": "无线电视台", "category": "Organization"},
{"name": "哥哥的女友", "category": "Works"},
{"name": "潘家伟", "category": "Person"},
{"name": "温兆伦", "category": "Person"},
{"name": "郭晋安", "category": "Person"},
],
"output": [
["周星驰", "Person", "被调入", "无线电视台", "Organization"],
["周星驰", "Person", "出演", "哥哥的女朋友", "Works"],
["周星驰", "Person", "饰演", "潘家伟", "Person"],
["周星驰", "Person", "共演", "温兆伦", "Person"],
["周星驰", "Person", "共演", "郭晋安", "Person"],
[
"周星驰",
"Person",
"跑龙套",
"温兆伦、郭晋安等人主演的电视剧",
"Works",
],
],
},
}
template_en: dict = {
"instruction": "You are an expert in Open Information Extraction (OpenIE). The schema defines the entity types you need to focus on, along with optional type explanations enclosed in parentheses. The entity_list is a set of entity lists. Please extract any possible [subject entity, subject entity class type, predicate, object entity, object entity type] quintuples from the text in the input field and list them in JSON list format. Please adhere strictly to the following requirements:1. At least one of the subject entity and object entity must appear in the entity_list.\n2. The subject and object entity types must be defined in the schema; otherwise, they are considered invalid.\n3.Resolve pronouns to their corresponding names explicitly to maintain clarity.",
"example": {
"input": "In 1986, Stephen Chow was transferred to the drama division of TVB; that same year, he played the cute, lively, and slightly shy Pan Jiawei in the situational drama 'My Brother's Girlfriend,' which was also his first time as the male lead in a situational drama; later, he also appeared as an extra in TV dramas starring Deric Wan, Roger Kwok, and others.",
"entity_list": [
{"name": "Stephen Chow", "category": "Person"},
{"name": "TVB", "category": "Organization"},
{"name": "My Brother's Girlfriend", "category": "Works"},
{"name": "Pan Jiawei", "category": "Person"},
{"name": "Deric Wan", "category": "Person"},
{"name": "Roger Kwok", "category": "Person"},
],
"output": [
["Stephen Chow", "Person", "was transferred to", "TVB", "Organization"],
[
"Stephen Chow",
"Person",
"starred in",
"My Brother's Girlfriend",
"Works",
],
["Stephen Chow", "Person", "played", "Pan Jiawei", "Person"],
["Stephen Chow", "Person", "co-starred with", "Deric Wan", "Person"],
["Stephen Chow", "Person", "co-starred with", "Roger Kwok", "Person"],
[
"Stephen Chow",
"Person",
"appeared as an extra in",
"TV dramas starring Deric Wan, Roger Kwok, and others",
"Works",
],
],
},
}
def get_accept_types(self):
"""
Returns the list of accepted SPG types.
Returns:
List[SpgTypeEnum]: List of accepted SPG types.
"""
return [
SpgTypeEnum.Entity,
SpgTypeEnum.Concept,
]
def build_prompt(self, variables: Dict[str, str]) -> str:
"""
Builds the prompt using the provided variables.
Args:
variables (Dict[str, str]): Dictionary of variables to be used in the prompt.
Returns:
str: The built prompt.
"""
schema = []
for item in self.prompt_schema:
schema.extend(item.keys())
return super().build_prompt(
{
"schema": schema,
"input": variables.get("input"),
}
)
def parse_response(self, response: str, **kwargs) -> List[SPGRecord]:
"""
Parses the response string into a list of SPG records.
Args:
response (str): The response string to be parsed.
**kwargs: Additional keyword arguments.
Returns:
List[SPGRecord]: List of parsed SPG records.
"""
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
outputs = []
for item in rsp:
if len(item) != 5:
continue
s_name, s_label, predicate, o_name, o_label = item
s_label = self.process_property_name(s_label)
o_label = self.process_property_name(o_label)
# force convert to OTHER_TYPE or just drop it?
if s_label not in self.schema:
s_label = OTHER_TYPE
if o_label not in self.schema:
o_label = OTHER_TYPE
outputs.append([s_name, s_label, predicate, o_name, o_label])
return outputs
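As a quick orientation, here is a minimal usage sketch of the relation prompt defined above; it assumes a project schema is already configured, that the prompt can be constructed with default arguments, and uses `llm_json` as a stand-in for the raw model output (none of this is part of the diff):

from kag.interface import PromptABC

# Hypothetical sketch: build the prompt text and parse a model response.
relation_prompt = PromptABC.from_config({"type": "spg_relation"})
prompt_text = relation_prompt.build_prompt(
    {"input": "In 1986, Stephen Chow was transferred to the drama division of TVB."}
)
# Placeholder LLM output; parse_response keeps well-formed quintuples and maps
# unknown labels to OTHER_TYPE.
llm_json = '[["Stephen Chow", "Person", "was transferred to", "TVB", "Organization"]]'
triples = relation_prompt.parse_response(llm_json)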


@ -10,20 +10,15 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Dict, List
from knext.common.base.runnable import Input, Output
from kag.interface.builder import ExtractorABC
from kag.interface import PromptABC
class UserDefinedExtractor(ExtractorABC):
    @property
    def input_types(self) -> Input:
        return Dict[str, str]

    @property
    def output_types(self) -> Output:
        return Dict[str, str]

    def invoke(self, input: Input, **kwargs) -> List[Output]:
        return input


def init_prompt_with_fallback(prompt_name, biz_scene):
    try:
        return PromptABC.from_config({"type": f"{biz_scene}_{prompt_name}"})
    except Exception as e:
        print(
            f"fail to initialize prompts with biz scene {biz_scene}, fallback to default biz scene, info: {e}"
        )
    return PromptABC.from_config({"type": f"default_{prompt_name}"})
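A small illustrative call of the fallback helper above (the biz scene and prompt name are placeholders): it first tries the scene-specific registration and, on failure, falls back to the default one.

# Hypothetical example: resolves "medical_ner" if registered, otherwise "default_ner".
ner_prompt = init_prompt_with_fallback("ner", "medical")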

kag/builder/runner.py

@ -0,0 +1,221 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import traceback
import logging
import threading
from typing import Dict
from tqdm import tqdm
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.registry import Registrable
from kag.common.utils import reset, bold, red, generate_hash_id
from kag.common.checkpointer import CheckpointerManager
from kag.interface import KAGBuilderChain, ScannerABC
from kag.builder.model.sub_graph import SubGraph
from concurrent.futures import ThreadPoolExecutor, as_completed
logger = logging.getLogger()
def str_abstract(value: str):
"""
Abstracts a string value by returning the base name if it is a file path, or the first 10 characters otherwise.
Args:
value (str): The string value to be abstracted.
Returns:
str: The abstracted string value.
"""
if os.path.exists(value):
return os.path.basename(value)
return value[:10]
def dict_abstract(value: Dict):
"""
Abstracts each value in a dictionary by converting it to a string and then abstracting the string.
Args:
value (Dict): The dictionary to be abstracted.
Returns:
Dict: The abstracted dictionary.
"""
output = {}
for k, v in value.items():
output[k] = str_abstract(str(v))
return output
def generate_hash_id_and_abstract(value):
hash_id = generate_hash_id(value)
if isinstance(value, dict):
abstract = dict_abstract(value)
else:
abstract = str_abstract(value)
return hash_id, abstract
class BuilderChainRunner(Registrable):
"""
A class that manages the execution of a KAGBuilderChain with parallel processing and checkpointing.
This class provides methods to initialize the runner, process input data, and manage checkpoints for tracking processed data.
"""
def __init__(
self,
scanner: ScannerABC,
chain: KAGBuilderChain,
num_chains: int = 2,
num_threads_per_chain: int = 8,
):
"""
Initializes the BuilderChainRunner instance.
Args:
scanner (ScannerABC): The source scanner to generate input data.
chain (KAGBuilderChain): The builder chain to process the input data.
num_chains (int, optional): The number of parallel threads to use, with each thread launching a builder chain instance. Defaults to 2.
num_threads_per_chain (int, optional): The number of parallel workers within a builder chain. Defaults to 8.
Note: checkpoint files are stored under the directory given by KAG_PROJECT_CONF.ckpt_dir.
"""
self.scanner = scanner
self.chain = chain
self.num_chains = num_chains
self.num_threads_per_chain = num_threads_per_chain
self.ckpt_dir = KAG_PROJECT_CONF.ckpt_dir
self.checkpointer = CheckpointerManager.get_checkpointer(
{
"type": "txt",
"ckpt_dir": self.ckpt_dir,
"rank": self.scanner.sharding_info.get_rank(),
"world_size": self.scanner.sharding_info.get_world_size(),
}
)
self.processed_chunks = CheckpointerManager.get_checkpointer(
{
"type": "zodb",
"ckpt_dir": os.path.join(self.ckpt_dir, "chain"),
"rank": self.scanner.sharding_info.get_rank(),
"world_size": self.scanner.sharding_info.get_world_size(),
}
)
self._local = threading.local()
def invoke(self, input):
"""
Processes the input data using the builder chain in parallel and manages checkpoints.
Args:
input: The input data to be processed.
"""
# def process(thread_local, chain_conf, data, data_id, data_abstract):
# try:
# if not hasattr(thread_local, "chain"):
# if chain_conf:
# thread_local.chain = KAGBuilderChain.from_config(chain_conf)
# else:
# thread_local.chain = self.chain
# result = thread_local.chain.invoke(
# data, max_workers=self.num_threads_per_chain
# )
# return data, data_id, data_abstract, result
# except Exception:
# traceback.print_exc()
# return None
def process(data, data_id, data_abstract):
try:
result = self.chain.invoke(
data,
max_workers=self.num_threads_per_chain,
processed_chunk_keys=self.processed_chunks.keys(),
)
return data, data_id, data_abstract, result
except Exception:
traceback.print_exc()
return None
futures = []
print(f"Processing {input}")
success = 0
try:
with ThreadPoolExecutor(self.num_chains) as executor:
for item in self.scanner.generate(input):
item_id, item_abstract = generate_hash_id_and_abstract(item)
if self.checkpointer.exists(item_id):
continue
fut = executor.submit(
process,
item,
item_id,
item_abstract,
)
futures.append(fut)
success = 0
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Progress",
position=0,
):
result = future.result()
if result is not None:
item, item_id, item_abstract, chain_output = result
info = {}
num_nodes = 0
num_edges = 0
num_subgraphs = 0
for item in chain_output:
if isinstance(item, SubGraph):
num_nodes += len(item.nodes)
num_edges += len(item.edges)
num_subgraphs += 1
elif isinstance(item, dict):
for k, v in item.items():
self.processed_chunks.write_to_ckpt(k, k)
if isinstance(v, SubGraph):
num_nodes += len(v.nodes)
num_edges += len(v.edges)
num_subgraphs += 1
info = {
"num_nodes": num_nodes,
"num_edges": num_edges,
"num_subgraphs": num_subgraphs,
}
self.checkpointer.write_to_ckpt(
item_id, {"abstract": item_abstract, "graph_stat": info}
)
success += 1
except:
traceback.print_exc()
CheckpointerManager.close()
msg = (
f"{bold}{red}Done process {len(futures)} records, with {success} successfully processed and {len(futures)-success} failures encountered.\n"
f"The log file is located at {self.checkpointer._ckpt_file_path}. "
f"Please access this file to obtain detailed task statistics.{reset}"
)
print(msg)
BuilderChainRunner.register("base", as_default=True)(BuilderChainRunner)
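For context, a hedged sketch of how this runner might be driven once a scanner and builder chain are configured; the component type names below are placeholders, not components defined in this PR:

from kag.builder.runner import BuilderChainRunner

# Hypothetical configuration; real scanner/chain configs come from kag_config.yaml.
runner = BuilderChainRunner.from_config(
    {
        "scanner": {"type": "file_scanner"},               # placeholder scanner config
        "chain": {"type": "unstructured_builder_chain"},   # placeholder chain config
        "num_chains": 2,
        "num_threads_per_chain": 8,
    }
)
runner.invoke("./data/corpus.json")  # items yielded by the scanner are processed in parallel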


@ -9,4 +9,3 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

File diff suppressed because one or more lines are too long


@ -1,184 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import importlib
import inspect
import os
import sys
from abc import ABC
from string import Template
from typing import List
BUILDER_PROMPT_PATH = "kag.builder.prompt"
SOLVER_PROMPT_PATH = "kag.solver.prompt"
class PromptOp(ABC):
"""
Provides a template for generating and parsing prompts related to specific business scenes.
Subclasses must implement the template strings for specific languages (English or Chinese)
and override the `template_variables` and `parse_response` methods.
"""
"""English template string"""
template_en: str = ""
"""Chinese template string"""
template_zh: str = ""
def __init__(self, language: str, **kwargs):
"""
Initializes the PromptOp instance with the selected language.
Args:
language (str): The language for the prompt, should be either "en" or "zh".
Raises:
AssertionError: If the provided language is not supported.
"""
assert language in ["en", "zh"], f"language[{language}] is not supported."
self.template = self.template_en if language == "en" else self.template_zh
self.language = language
self.template_variables_value = {}
if "project_id" in kwargs:
self.project_id = kwargs["project_id"]
@property
def template_variables(self) -> List[str]:
"""
Gets the list of template variables.
Must be implemented by subclasses.
Returns:
- List[str]: A list of template variable names.
Raises:
- NotImplementedError: If the subclass does not implement this method.
"""
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `template_variables` method."
)
def process_template_string_to_avoid_dollar_problem(self, template_string):
new_template_str = template_string.replace('$', '$$')
for var in self.template_variables:
new_template_str = new_template_str.replace(f'$${var}', f'${var}')
return new_template_str
def build_prompt(self, variables) -> str:
"""
Build a prompt based on the template and provided variables.
This method replaces placeholders in the template with actual variable values.
If a variable is not provided, it defaults to an empty string.
Parameters:
- variables: A dictionary containing variable names and their corresponding values.
Returns:
- A string or list of strings, depending on the template content.
"""
self.template_variables_value = variables
template_string = self.process_template_string_to_avoid_dollar_problem(self.template)
template = Template(template_string)
return template.substitute(**variables)
def parse_response(self, response: str, **kwargs):
"""
Parses the response string.
Must be implemented by subclasses.
Parameters:
- response (str): The response string to be parsed.
Raises:
- NotImplementedError: If the subclass does not implement this method.
"""
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `parse_response` method."
)
@classmethod
def load(cls, biz_scene: str, type: str):
"""
Dynamically loads the corresponding PromptOp subclass object based on the business scene and type.
Parameters:
- biz_scene (str): The name of the business scene.
- type (str): The type of prompt.
Returns:
- subclass of PromptOp: The loaded PromptOp subclass object.
Raises:
- ImportError: If the specified module or class does not exist.
"""
dir_paths = [
os.path.join(os.getenv("KAG_PROJECT_ROOT_PATH", ""), "builder", "prompt"),
os.path.join(os.getenv("KAG_PROJECT_ROOT_PATH", ""), "solver", "prompt"),
]
module_paths = [
'.'.join([BUILDER_PROMPT_PATH, biz_scene, type]),
'.'.join([SOLVER_PROMPT_PATH, biz_scene, type]),
'.'.join([BUILDER_PROMPT_PATH, 'default', type]),
'.'.join([SOLVER_PROMPT_PATH, 'default', type]),
]
def find_class_from_dir(dir, type):
sys.path.append(dir)
for root, dirs, files in os.walk(dir):
for file in files:
if file.endswith(".py") and file.startswith(f"{type}."):
module_name = file[:-3]
try:
module = importlib.import_module(module_name)
except ImportError:
continue
cls_found = find_class_from_module(module)
if cls_found:
return cls_found
return None
def find_class_from_module(module):
classes = inspect.getmembers(module, inspect.isclass)
for class_name, class_obj in classes:
import kag
if issubclass(class_obj, kag.common.base.prompt_op.PromptOp) and inspect.getmodule(class_obj) == module:
return class_obj
return None
for dir_path in dir_paths:
try:
cls_found = find_class_from_dir(dir_path, type)
if cls_found:
return cls_found
except ImportError:
continue
for module_path in module_paths:
try:
module = importlib.import_module(module_path)
cls_found = find_class_from_module(module)
if cls_found:
return cls_found
except ModuleNotFoundError:
continue
raise ValueError(f'Not support prompt with biz_scene[{biz_scene}] and type[{type}]')


@ -1,5 +1,7 @@
import re
import json
import string
import traceback
from collections import Counter
@ -17,15 +19,16 @@ def normalize_answer(s):
Returns:
str: The standardized answer string.
"""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return ' '.join(text.split())
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return str(text).lower()
@ -52,10 +55,16 @@ def f1_score(prediction, ground_truth):
ZERO_METRIC = (0, 0, 0)
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
if (
normalized_prediction in ["yes", "no", "noanswer"]
and normalized_prediction != normalized_ground_truth
):
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
if (
normalized_ground_truth in ["yes", "no", "noanswer"]
and normalized_prediction != normalized_ground_truth
):
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
@ -78,35 +87,156 @@ def f1_score(prediction, ground_truth):
def exact_match_score(prediction, ground_truth):
"""
Calculates the exact match score between a predicted answer and the ground truth answer.
This function normalizes both the predicted answer and the ground truth answer before comparing them.
Normalization is performed to ensure that non-essential differences such as spaces and case are ignored.
Parameters:
prediction (str): The predicted answer string.
ground_truth (str): The ground truth answer string.
Returns:
int: 1 if the predicted answer exactly matches the ground truth answer, otherwise 0.
"""
return 1 if normalize_answer(prediction) == normalize_answer(ground_truth) else 0
def get_em_f1(prediction, gold):
"""
Calculates the Exact Match (EM) score and F1 score between the prediction and the gold standard.
This function evaluates the performance of a model in text similarity tasks by calculating the EM score and F1 score to measure the accuracy of the predictions.
Parameters:
prediction (str): The output predicted by the model.
gold (str): The gold standard output (i.e., the correct output).
Returns:
tuple: A tuple containing two floats, the EM score and the F1 score. The EM score represents the exact match accuracy, while the F1 score is a combination of precision and recall.
"""
em = exact_match_score(prediction, gold)
f1, precision, recall = f1_score(prediction, gold)
return float(em), f1
return float(em), f1
def compare_summarization_answers(
query,
answer1,
answer2,
*,
api_key="EMPTY",
base_url="http://127.0.0.1:38080/v1",
model="gpt-4o-mini",
language="English",
retries=3,
):
"""
Given a query and two answers, compare the answers with an LLM for Comprehensiveness, Diversity and Empowerment.
This function is adapted from LightRAG for evaluating GraphRAG and LightRAG in QFS (query-focused summarization)
tasks:
https://github.com/HKUDS/LightRAG/blob/45cea6e/examples/batch_eval.py
Parameters:
query (str): The query input to the LLMs.
answer1 (str): Answer generated by an LLM.
answer2 (str): Answer generated by another LLM.
api_key (str): API key to use when invoking the evaluating LLM.
base_url (str): Base URL to use when invoking the evaluating LLM.
model (str): Model name to use when invoking the evaluating LLM.
language (str): Language of the explanation.
retries (int): Number of retries.
Returns:
str: response content generated by the evaluating LLM.
"""
from openai import OpenAI
sys_prompt = """
---Role---
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
"""
prompt = f"""
You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
- **Empowerment**: How well does the answer help the reader understand and make informed judgments about the topic?
For each criterion, give each answer a score between 0 and 10, choose the better answer (either Answer 1 or Answer 2) and explain why.
Then, give each answer an overall score between 0 and 10, and select an overall winner based on these three categories.
Here is the question:
{query}
Here are the two answers:
**Answer 1:**
{answer1}
**Answer 2:**
{answer2}
Evaluate both answers using the three criteria listed above and provide detailed explanations for each criterion.
Output your evaluation in the following JSON format:
{{
"Comprehensiveness": {{
"Score 1": [Score of Answer 1 - an integer between 0 and 10],
"Score 2": [Score of Answer 2 - an integer between 0 and 10],
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation in {language} here]"
}},
"Diversity": {{
"Score 1": [Score of Answer 1 - an integer between 0 and 10],
"Score 2": [Score of Answer 2 - an integer between 0 and 10],
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation in {language} here]"
}},
"Empowerment": {{
"Score 1": [Score of Answer 1 - an integer between 0 and 10],
"Score 2": [Score of Answer 2 - an integer between 0 and 10],
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Provide explanation in {language} here]"
}},
"Overall": {{
"Score 1": [Score of Answer 1 - an integer between 0 and 10],
"Score 2": [Score of Answer 2 - an integer between 0 and 10],
"Winner": "[Answer 1 or Answer 2]",
"Explanation": "[Summarize why this answer is the overall winner based on the three criteria in {language}]"
}}
}}
"""
for index in range(retries):
content = None
try:
client = OpenAI(api_key=api_key, base_url=base_url)
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": sys_prompt},
{"role": "user", "content": prompt},
],
)
content = response.choices[0].message.content
if content.startswith("```json") and content.endswith("```"):
content = content[7:-3]
metrics = json.loads(content)
return metrics
except Exception:
if index == retries - 1:
message = (
f"Comparing summarization answers failed.\n"
f"query: {query}\n"
f"answer1: {answer1}\n"
f"answer2: {answer2}\n"
f"content: {content}\n"
f"exception:\n{traceback.format_exc()}"
)
print(message)
return None
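A brief usage sketch of the comparison helper above, using the defaults from its signature; the query and answers are illustrative:

# Hypothetical inputs; the helper returns the parsed JSON metrics or None on failure.
metrics = compare_summarization_answers(
    "What drove the company's revenue growth?",
    "Answer produced by system A ...",
    "Answer produced by system B ...",
    api_key="EMPTY",
    base_url="http://127.0.0.1:38080/v1",
    model="gpt-4o-mini",
)
if metrics is not None:
    print(metrics["Overall"]["Winner"])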


@ -1,22 +1,25 @@
from typing import List
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from .evaUtils import get_em_f1
from .evaUtils import compare_summarization_answers
class Evaluate():
class Evaluate:
"""
provide evaluation for benchmarks, such as em, f1, answer_similarity, answer_correctness
"""
def __init__(self, embedding_factory = "text-embedding-ada-002"):
def __init__(self, embedding_factory="text-embedding-ada-002"):
self.embedding_factory = embedding_factory
def evaForSimilarity(self, predictionlist: List[str], goldlist: List[str]):
"""
evaluate the similarity between prediction and gold #TODO
"""
# data_samples = {
# data_samples = {
# 'question': [],
# 'answer': predictionlist,
# 'ground_truth': goldlist
@ -29,7 +32,6 @@ class Evaluate():
# return np.average(score.to_pandas()[['answer_similarity']])
return 0.0
def getBenchMark(self, predictionlist: List[str], goldlist: List[str]):
"""
Calculates and returns evaluation metrics between predictions and ground truths.
@ -45,21 +47,113 @@ class Evaluate():
dict: Dictionary containing EM, F1 score, and answer similarity.
"""
# Initialize total metrics
total_metrics = {'em': 0.0, 'f1': 0.0, 'answer_similarity': 0.0}
total_metrics = {"em": 0.0, "f1": 0.0, "answer_similarity": 0.0}
# Iterate over prediction and gold lists to calculate EM and F1 scores
for prediction, gold in zip(predictionlist, goldlist):
em, f1 = get_em_f1(prediction, gold) # Call external function to calculate EM and F1
total_metrics['em'] += em # Accumulate EM score
total_metrics['f1'] += f1 # Accumulate F1 score
em, f1 = get_em_f1(
prediction, gold
) # Call external function to calculate EM and F1
total_metrics["em"] += em # Accumulate EM score
total_metrics["f1"] += f1 # Accumulate F1 score
# Calculate average EM and F1 scores
total_metrics['em'] /= len(predictionlist)
total_metrics['f1'] /= len(predictionlist)
total_metrics["em"] /= len(predictionlist)
total_metrics["f1"] /= len(predictionlist)
# Call method to calculate answer similarity
total_metrics['answer_similarity'] = self.evaForSimilarity(predictionlist, goldlist)
total_metrics["answer_similarity"] = self.evaForSimilarity(
predictionlist, goldlist
)
# Return evaluation metrics dictionary
return total_metrics
def getSummarizationMetrics(
self,
queries: List[str],
answers1: List[str],
answers2: List[str],
*,
api_key="EMPTY",
base_url="http://127.0.0.1:38080/v1",
model="gpt-4o-mini",
language="English",
retries=3,
max_workers=50,
):
"""
Calculates and returns QFS (query-focused summarization) evaluation metrics
for the given queries, answers1 and answers2.
This function evaluates the triple (query, answer1, answer2) by feeding it
into an evaluating LLM specified as `api_key`, `base_url` and `model`.
Parameters:
queries (List[str]): List of queries.
answers1 (List[str]): List of answers generated by an LLM (LLM-1).
answers2 (List[str]): List of answers generated by another LLM (LLM-2).
api_key (str): API key to use when invoking the evaluating LLM.
base_url (str): Base URL to use when invoking the evaluating LLM.
model (str): Model name to use when invoking the evaluating LLM.
language (str): Language of the explanation.
retries (int): Number of retries.
max_workers (int): Number of workers.
Returns:
dict: Dictionary containing the average metrics and the responses
generated by the evaluating LLM.
"""
responses = [None] * len(queries)
all_keys = "Comprehensiveness", "Diversity", "Empowerment", "Overall"
all_items = "Score 1", "Score 2"
average_metrics = {key: {item: 0.0 for item in all_items} for key in all_keys}
success_count = 0
def process_sample(index, query, answer1, answer2):
metrics = compare_summarization_answers(
query,
answer1,
answer2,
api_key=api_key,
base_url=base_url,
model=model,
language=language,
retries=retries,
)
if metrics is None:
print(
f"fail to compare answers of query {index + 1}.\n"
f" query: {query}\n"
f" answer1: {answer1}\n"
f" answer2: {answer2}\n"
)
else:
responses[index] = metrics
return metrics
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(process_sample, index, query, answer1, answer2)
for index, (query, answer1, answer2) in enumerate(
zip(queries, answers1, answers2)
)
]
for future in tqdm(
as_completed(futures), total=len(futures), desc="Evaluating: "
):
metrics = future.result()
if metrics is not None:
for key in all_keys:
for item in all_items:
average_metrics[key][item] += metrics[key][item]
success_count += 1
if success_count > 0:
for key in all_keys:
for item in all_items:
average_metrics[key][item] /= success_count
result = {
"average_metrics": average_metrics,
"responses": responses,
}
return result
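For reference, a minimal sketch of calling the two evaluation entry points above; it assumes the Evaluate class is in scope and uses toy data:

# Hypothetical usage of the benchmark and QFS evaluation paths.
evaluator = Evaluate()
metrics = evaluator.getBenchMark(["Paris"], ["Paris"])
print(metrics["em"], metrics["f1"])
qfs = evaluator.getSummarizationMetrics(
    queries=["What drove the company's revenue growth?"],
    answers1=["Answer from system A ..."],
    answers2=["Answer from system B ..."],
    max_workers=4,
)
print(qfs["average_metrics"]["Overall"])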


@ -9,15 +9,9 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.common.checkpointer.base import CheckPointer, CheckpointerManager
from kag.common.checkpointer.txt_checkpointer import TxtCheckPointer
from kag.common.checkpointer.bin_checkpointer import BinCheckPointer
from kag.common.llm.config.openai import OpenAIConfig
from kag.common.llm.config.base import LLMConfig
from kag.common.llm.config.vllm import VLLMConfig
from kag.common.llm.config.ollama import OllamaConfig
__all__ = [
"OpenAIConfig",
"LLMConfig",
"VLLMConfig",
"OllamaConfig"
]
__all__ = ["CheckPointer", "CheckpointerManager", "TxtCheckPointer", "BinCheckPointer"]


@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import threading
from kag.common.registry import Registrable
from kag.common.utils import reset, bold, red, generate_hash_id
class CheckPointer(Registrable):
"""
A class for managing checkpoints in a distributed environment.
This class provides methods to open, read, write, and close checkpoint files.
It is designed to handle checkpoints in a distributed setting, where multiple
processes may be writing checkpoints in parallel.
Attributes:
ckpt_file_name (str): The format string for checkpoint file names.
"""
ckpt_file_name = "kag_checkpoint_{}_{}.ckpt"
def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
"""
Initializes the CheckPointer with the given checkpoint directory, rank, and world size.
Args:
ckpt_dir (str): The directory where checkpoint files are stored.
rank (int): The rank of the current process (default is 0).
world_size (int): The total number of processes in the distributed environment (default is 1).
"""
self._ckpt_dir = ckpt_dir
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir, exist_ok=True)
self.rank = rank
self.world_size = world_size
self._ckpt_file_path = os.path.join(
self._ckpt_dir, CheckPointer.ckpt_file_name.format(rank, world_size)
)
self._ckpt = self.open()
self._closed = False
if self.size() > 0:
print(
f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}, with {self.size()} records.{reset}"
)
def open(self):
"""
Opens the checkpoint file and returns the checkpoint object.
Returns:
Any: The checkpoint object, which can be used for reading and writing.
"""
raise NotImplementedError("open not implemented yet.")
def read_from_ckpt(self, key):
"""
Reads a value from the checkpoint file using the specified key.
Args:
key (str): The key to retrieve the value from the checkpoint.
Returns:
Any: The value associated with the key in the checkpoint.
"""
raise NotImplementedError("read_from_ckpt not implemented yet.")
def write_to_ckpt(self, key, value):
"""
Writes a value to the checkpoint file using the specified key.
Args:
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""
raise NotImplementedError("write_to_ckpt not implemented yet.")
def _close(self):
"""
Closes the checkpoint file.
"""
raise NotImplementedError("close not implemented yet.")
def close(self):
"""
Closes the checkpoint file.
"""
if not self._closed:
self._close()
self._closed = True
def exists(self, key):
"""
Checks if a key exists in the checkpoint file.
Args:
key (str): The key to check for existence in the checkpoint.
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
raise NotImplementedError("exists not implemented yet.")
def keys(self):
"""
Returns the key set contained in the checkpoint file.
Returns:
set: The key set contained in the checkpoint.
"""
raise NotImplementedError("keys not implemented yet.")
def size(self):
"""
Return the number of records in the checkpoint file.
Returns:
int: the number of records in the checkpoint file.
"""
raise NotImplementedError("size not implemented yet.")
def __contains__(self, key):
"""
Defines the behavior of the `in` operator for the object.
Args:
key (str): The key to check for existence in the checkpoint.
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
return self.exists(key)
class CheckpointerManager:
"""
Manages the lifecycle of CheckPointer objects.
This class provides a thread-safe mechanism to retrieve and close CheckPointer
instances based on a configuration. It uses a global dictionary to cache
CheckPointer objects, ensuring that each configuration corresponds to a unique
instance.
"""
_CKPT_OBJS = {}
_LOCK = threading.Lock()
@staticmethod
def get_checkpointer(config):
"""
Retrieves or creates a CheckPointer instance based on the provided configuration.
Args:
config (dict): The configuration used to initialize the CheckPointer.
Returns:
CheckPointer: A CheckPointer instance corresponding to the configuration.
"""
with CheckpointerManager._LOCK:
key = generate_hash_id(config)
if key not in CheckpointerManager._CKPT_OBJS:
ckpter = CheckPointer.from_config(config)
CheckpointerManager._CKPT_OBJS[key] = ckpter
return CheckpointerManager._CKPT_OBJS[key]
@staticmethod
def close():
"""
Closes all cached CheckPointer instances.
This method iterates through all cached CheckPointer objects and calls their
`close` method to release resources. After calling this method, the cache
will be cleared.
"""
with CheckpointerManager._LOCK:
for v in CheckpointerManager._CKPT_OBJS.values():
v.close()
CheckpointerManager._CKPT_OBJS.clear()
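As a usage sketch, the manager above is driven the same way BuilderChainRunner does it earlier in this PR; the directory here is illustrative:

from kag.common.checkpointer import CheckpointerManager

# Hypothetical keys/values; "txt" is one of the registered checkpointer types.
ckpt = CheckpointerManager.get_checkpointer({"type": "txt", "ckpt_dir": "./ckpt"})
if not ckpt.exists("doc-001"):
    ckpt.write_to_ckpt("doc-001", {"status": "done"})
print(ckpt.read_from_ckpt("doc-001"))
CheckpointerManager.close()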


@ -0,0 +1,217 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import shelve
import logging
import transaction
import threading
import pickle
import BTrees.OOBTree
from ZODB import DB
from ZODB.FileStorage import FileStorage
from kag.common.checkpointer.base import CheckPointer
logger = logging.getLogger()
@CheckPointer.register("bin")
class BinCheckPointer(CheckPointer):
"""
A subclass of CheckPointer that uses shelve for binary checkpoint management.
This class extends the CheckPointer class to provide binary checkpoint
management using the shelve module. It supports opening, reading, writing,
and closing checkpoint files in a binary format.
"""
def open(self):
"""
Opens the checkpoint file using shelve in writeback mode.
Returns:
Any: The shelve object representing the checkpoint file.
"""
return shelve.open(self._ckpt_file_path, "c", writeback=True)
def exists(self, key):
"""
Checks if a key exists in the checkpoint file.
Args:
key (str): The key to check for existence in the checkpoint.
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
return key in self._ckpt
def read_from_ckpt(self, key):
"""
Reads a value from the checkpoint file using the specified key.
Args:
key (str): The key to retrieve the value from the checkpoint.
Returns:
Any: The value associated with the key in the checkpoint.
"""
return self._ckpt[key]
def write_to_ckpt(self, key, value):
"""
Writes a value to the checkpoint file using the specified key.
Args:
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""
self._ckpt[key] = value
self._ckpt.sync()
def _close(self):
"""
Closes the checkpoint file and ensures data is written to disk.
"""
self._ckpt.sync()
self._ckpt.close()
def size(self):
"""
Returns the number of entries in the checkpoint.
Returns:
int: The number of entries in the checkpoint.
"""
return len(self._ckpt)
def keys(self):
return set(self._ckpt.keys())
@CheckPointer.register("zodb")
class ZODBCheckPointer(CheckPointer):
"""
A CheckPointer implementation that uses ZODB as the underlying storage.
This class provides methods to open, read, write, and close checkpoints using ZODB.
"""
def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
"""
Initializes the ZODBCheckPointer with the given checkpoint directory, rank, and world size.
Args:
ckpt_dir (str): The directory where checkpoint files are stored.
rank (int): The rank of the current process (default is 0).
world_size (int): The total number of processes in the distributed environment (default is 1).
"""
self._lock = threading.Lock()
super().__init__(ckpt_dir, rank, world_size)
def open(self):
"""
Opens the ZODB database and returns the root object for checkpoint storage.
Returns:
dict: The root object of the ZODB database, which is a dictionary-like object.
"""
with self._lock:
storage = FileStorage(self._ckpt_file_path)
db = DB(storage)
with db.transaction() as conn:
if not hasattr(conn.root, "data"):
conn.root.data = BTrees.OOBTree.BTree()
return db
def read_from_ckpt(self, key):
"""
Reads a value from the checkpoint using the specified key.
Args:
key (str): The key to retrieve the value from the checkpoint.
Returns:
Any: The value associated with the key in the checkpoint.
"""
with self._lock:
with self._ckpt.transaction() as conn:
obj = conn.root.data.get(key, None)
if obj:
return pickle.loads(obj)
else:
return None
def write_to_ckpt(self, key, value):
"""
Writes a value to the checkpoint using the specified key.
By default, ZODB tracks modifications to the written object (value) and
continuously synchronizes these changes to the storage. For example, if
the value is a `SubGraph` object, subsequent modifications to its
attributes will be synchronized, which is not what we expect.
Therefore, we use `pickle` to serialize the value object before writing it,
ensuring that the object behaves as an immutable object.
Args:
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""
with self._lock:
try:
with self._ckpt.transaction() as conn:
conn.root.data[key] = pickle.dumps(value)
except Exception as e:
logger.warn(f"failed to write checkpoint {key} to db, info: {e}")
def _close(self):
"""
Closes the ZODB database connection.
"""
with self._lock:
try:
transaction.commit()
except:
transaction.abort()
if self._ckpt is not None:
self._ckpt.close()
def exists(self, key):
"""
Checks if a key exists in the checkpoint.
Args:
key (str): The key to check for existence in the checkpoint.
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
with self._lock:
with self._ckpt.transaction() as conn:
return key in conn.root.data
def size(self):
"""
Returns the number of entries in the checkpoint.
This method calculates the size of the checkpoint by counting the number
of keys stored in the checkpoint's data dictionary. It ensures thread-safe
access to the checkpoint by using a lock.
Returns:
int: The number of entries in the checkpoint.
"""
with self._lock:
with self._ckpt.transaction() as conn:
return len(conn.root.data)
def keys(self):
with self._lock:
with self._ckpt.transaction() as conn:
return set(conn.root.data.keys())
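A short sketch of the snapshot behaviour documented in write_to_ckpt above, reusing CheckpointerManager from the previous sketch (paths are illustrative and the ZODB dependency must be installed):

# Values are pickled on write, so later mutations do not leak into the store.
ckpt = CheckpointerManager.get_checkpointer({"type": "zodb", "ckpt_dir": "./ckpt/chain"})
graph = {"nodes": ["a"], "edges": []}
ckpt.write_to_ckpt("chunk-1", graph)
graph["nodes"].append("b")
assert ckpt.read_from_ckpt("chunk-1") == {"nodes": ["a"], "edges": []}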


@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import json
from kag.common.checkpointer.base import CheckPointer
@CheckPointer.register("txt")
class TxtCheckPointer(CheckPointer):
"""
A subclass of CheckPointer that uses a text file for checkpoint management.
This class extends the CheckPointer class to provide checkpoint management
using a text file. It supports opening, reading, writing, and closing
checkpoint files in a text format. Each checkpoint entry is stored as a
JSON object in the file.
"""
def open(self):
"""
Opens the checkpoint file and loads existing data into a dictionary.
Returns:
dict: A dictionary containing the checkpoint data.
"""
ckpt = {}
if os.path.exists(self._ckpt_file_path):
with open(self._ckpt_file_path, "r") as reader:
for line in reader:
data = json.loads(line)
ckpt[data["id"]] = data["value"]
self._writer = open(self._ckpt_file_path, "a")
return ckpt
def exists(self, key):
"""
Checks if a key exists in the checkpoint file.
Args:
key (str): The key to check for existence in the checkpoint.
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
return key in self._ckpt
def read_from_ckpt(self, key):
"""
Reads a value from the checkpoint file using the specified key.
Args:
key (str): The key to retrieve the value from the checkpoint.
Returns:
Any: The value associated with the key in the checkpoint.
"""
return self._ckpt[key]
def write_to_ckpt(self, key, value):
"""
Writes a value to the checkpoint file using the specified key.
Args:
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""
self._ckpt[key] = value
self._writer.write(json.dumps({"id": key, "value": value}, ensure_ascii=False))
self._writer.write("\n")
self._writer.flush()
def _close(self):
"""
Closes the checkpoint file and ensures data is written to disk.
"""
self._writer.flush()
self._writer.close()
def size(self):
return len(self._ckpt)
def keys(self):
return set(self._ckpt.keys())
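For clarity, each record the text checkpointer writes is one JSON line; the values below are illustrative of what BuilderChainRunner stores:

{"id": "doc-001", "value": {"abstract": "corpus.json", "graph_stat": {"num_nodes": 12, "num_edges": 30, "num_subgraphs": 2}}}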

kag/common/conf.py

@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
import os
import logging
import yaml
import json
import pprint
from pathlib import Path
from typing import Union, Optional
from knext.project.client import ProjectClient
class KAGConstants(object):
LOCAL_SCHEMA_URL = "http://localhost:8887"
DEFAULT_KAG_CONFIG_FILE_NAME = "default_config.yaml"
KAG_CONFIG_FILE_NAME = "kag_config.yaml"
DEFAULT_KAG_CONFIG_PATH = os.path.join(__file__, DEFAULT_KAG_CONFIG_FILE_NAME)
KAG_CFG_PREFIX = "KAG"
GLOBAL_CONFIG_KEY = "global"
PROJECT_CONFIG_KEY = "project"
KAG_NAMESPACE_KEY = "namespace"
KAG_PROJECT_ID_KEY = "id"
KAG_PROJECT_HOST_ADDR_KEY = "host_addr"
KAG_LANGUAGE_KEY = "language"
KAG_CKPT_DIR_KEY = "checkpoint_path"
KAG_BIZ_SCENE_KEY = "biz_scene"
ENV_KAG_PROJECT_ID = "KAG_PROJECT_ID"
ENV_KAG_PROJECT_HOST_ADDR = "KAG_PROJECT_HOST_ADDR"
ENV_KAG_DEBUG_DUMP_CONFIG = "KAG_DEBUG_DUMP_CONFIG"
KAG_SIMILAR_EDGE_NAME = "similar"
KS8_ENV_TF_CONFIG = "TF_CONFIG"
K8S_ENV_MASTER_ADDR = "MASTER_ADDR"
K8S_ENV_MASTER_PORT = "MASTER_PORT"
K8S_ENV_WORLD_SIZE = "WORLD_SIZE"
K8S_ENV_RANK = "RANK"
K8S_ENV_POD_NAME = "POD_NAME"
class KAGGlobalConf:
def __init__(self):
self._extra = {}
def initialize(self, **kwargs):
self.project_id = kwargs.pop(
KAGConstants.KAG_PROJECT_ID_KEY,
os.getenv(KAGConstants.ENV_KAG_PROJECT_ID, "1"),
)
self.host_addr = kwargs.pop(
KAGConstants.KAG_PROJECT_HOST_ADDR_KEY,
os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR, "http://127.0.0.1:8887"),
)
self.biz_scene = kwargs.pop(KAGConstants.KAG_BIZ_SCENE_KEY, "default")
self.language = kwargs.pop(KAGConstants.KAG_LANGUAGE_KEY, "en")
self.namespace = kwargs.pop(KAGConstants.KAG_NAMESPACE_KEY, None)
self.ckpt_dir = kwargs.pop(KAGConstants.KAG_CKPT_DIR_KEY, "ckpt")
# process configs set to class attr directly
for k in self._extra.keys():
if hasattr(self, k):
delattr(self, k)
for k, v in kwargs.items():
setattr(self, k, v)
self._extra = kwargs
print(
f"Done initialize project config with host addr {self.host_addr} and project_id {self.project_id}"
)
def _closest_cfg(
path: Union[str, os.PathLike] = ".",
prev_path: Optional[Union[str, os.PathLike]] = None,
) -> str:
"""
Return the path to the closest .kag.cfg file by traversing the current
directory and its parents
"""
if prev_path is not None and str(path) == str(prev_path):
return ""
path = Path(path).resolve()
cfg_file = path / KAGConstants.KAG_CONFIG_FILE_NAME
if cfg_file.exists():
return str(cfg_file)
return _closest_cfg(path.parent, path)
def load_config(prod: bool = False):
"""
Get kag config file as a ConfigParser.
"""
if prod:
project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
project_client = ProjectClient(host_addr=host_addr)
project = project_client.get_by_id(project_id)
config = json.loads(project.config)
if "project" not in config:
config["project"] = {
KAGConstants.KAG_PROJECT_ID_KEY: project_id,
KAGConstants.KAG_PROJECT_HOST_ADDR_KEY: host_addr,
KAGConstants.KAG_NAMESPACE_KEY: project.namespace,
}
prompt_config = config.pop("prompt", {})
for key in [KAGConstants.KAG_LANGUAGE_KEY, KAGConstants.KAG_BIZ_SCENE_KEY]:
if key in prompt_config:
config["project"][key] = prompt_config[key]
if "vectorizer" in config and "vectorize_model" not in config:
config["vectorize_model"] = config["vectorizer"]
return config
else:
config_file = _closest_cfg()
if os.path.exists(config_file) and os.path.isfile(config_file):
print(f"found config file: {config_file}")
with open(config_file, "r") as reader:
config = reader.read()
return yaml.safe_load(config)
else:
return {}
class KAGConfigMgr:
def __init__(self):
self.config = {}
self.global_config = KAGGlobalConf()
self._is_initialized = False
def init_log_config(self, config):
log_conf = config.get("log", {})
if log_conf:
log_level = log_conf.get("level", "INFO")
else:
log_level = "INFO"
logging.basicConfig(level=logging.getLevelName(log_level))
logging.getLogger("neo4j.notifications").setLevel(logging.ERROR)
logging.getLogger("neo4j.io").setLevel(logging.INFO)
logging.getLogger("neo4j.pool").setLevel(logging.INFO)
def initialize(self, prod: bool = True):
config = load_config(prod)
if self._is_initialized:
print(
"Reinitialize the KAG configuration, an operation that should exclusively be triggered within the Java invocation context."
)
print(f"original config: {self.config}")
print(f"new config: {config}")
self.prod = prod
self.config = config
global_config = self.config.get(KAGConstants.PROJECT_CONFIG_KEY, {})
self.global_config.initialize(**global_config)
self.init_log_config(self.config)
self._is_initialized = True
@property
def all_config(self):
return copy.deepcopy(self.config)
KAG_CONFIG = KAGConfigMgr()
KAG_PROJECT_CONF = KAG_CONFIG.global_config
def init_env():
project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
if project_id and host_addr:
prod = True
else:
prod = False
global KAG_CONFIG
KAG_CONFIG.initialize(prod)
if prod:
msg = "Done init config from server"
else:
msg = "Done init config from local file"
os.environ[KAGConstants.ENV_KAG_PROJECT_ID] = str(KAG_PROJECT_CONF.project_id)
os.environ[KAGConstants.ENV_KAG_PROJECT_HOST_ADDR] = str(KAG_PROJECT_CONF.host_addr)
if len(KAG_CONFIG.all_config) > 0:
dump_flag = os.getenv(KAGConstants.ENV_KAG_DEBUG_DUMP_CONFIG)
if dump_flag is not None and dump_flag.strip() == "1":
print(f"{msg}:")
pprint.pprint(KAG_CONFIG.all_config, indent=2)
else:
print(
f"{msg}: set {KAGConstants.ENV_KAG_DEBUG_DUMP_CONFIG}=1 to dump config"
)
else:
print("No config found.")


@ -1,33 +0,0 @@
[project]
with_server = True
host_addr = http://127.0.0.1:8887
[vectorizer]
vectorizer = kag.common.vectorizer.OpenAIVectorizer
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
client_type = ollama
base_url = http://localhost:11434/api/generate
model = llama3.1
[indexer]
with_semantic = False
similarity_threshold = 0.8
[retriever]
with_semantic = False
pagerank_threshold = 0.9
match_threshold = 0.8
top_k = 10
[schedule]
interval_minutes = -1
[log]
level = INFO


@ -1,117 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import sys
from configparser import ConfigParser as CP
from pathlib import Path
from typing import Union, Optional
import kag.common as common
class ConfigParser(CP):
def __init__(self,defaults=None):
CP.__init__(self,defaults=defaults)
def optionxform(self, optionstr):
return optionstr
import json
import time
import datetime
import socket
import traceback
from kag.common.conf import KAGConstants
LOCAL_SCHEMA_URL = "http://localhost:8887"
DEFAULT_KAG_CONFIG_FILE_NAME = "default_config.cfg"
DEFAULT_KAG_CONFIG_PATH = os.path.join(common.__path__[0], DEFAULT_KAG_CONFIG_FILE_NAME)
KAG_CFG_PREFIX = "KAG"
def parse_tf_config():
tf_config_str = os.environ.get(KAGConstants.KS8_ENV_TF_CONFIG, None)
if tf_config_str is None:
return None
else:
return json.loads(tf_config_str)
def init_env():
"""Initialize environment to use command-line tool from inside a project
dir. This sets the Scrapy settings module and modifies the Python path to
be able to locate the project module.
"""
project_cfg, root_path = get_config()
init_kag_config(Path(root_path) / "kag_config.cfg")
def get_role_number(config, role_name):
role_info = config["cluster"].get(role_name, None)
if role_info is None:
return 0
else:
return len(role_info)
def get_config():
"""
Get kag config file as a ConfigParser.
"""
local_cfg_path = _closest_cfg()
local_cfg = ConfigParser()
local_cfg.read(local_cfg_path)
def get_rank(default=None):
if KAGConstants.K8S_ENV_RANK in os.environ:
return int(os.environ[KAGConstants.K8S_ENV_RANK])
projdir = ""
if local_cfg_path:
projdir = str(Path(local_cfg_path).parent)
if projdir not in sys.path:
sys.path.append(projdir)
tf_config = parse_tf_config()
if tf_config is None:
return default
return local_cfg, projdir
num_master = get_role_number(tf_config, "master")
task_type = tf_config["task"]["type"]
task_index = tf_config["task"]["index"]
if task_type == "master":
rank = task_index
elif task_type == "worker":
rank = num_master + task_index
else:
rank = default
return rank
def _closest_cfg(
path: Union[str, os.PathLike] = ".",
prev_path: Optional[Union[str, os.PathLike]] = None,
) -> str:
"""
Return the path to the closest kag_config.cfg file by traversing the current
directory and its parents
"""
if prev_path is not None and str(path) == str(prev_path):
return ""
path = Path(path).resolve()
cfg_file = path / "kag_config.cfg"
if cfg_file.exists():
return str(cfg_file)
return _closest_cfg(path.parent, path)
def get_world_size(default=None):
if KAGConstants.K8S_ENV_WORLD_SIZE in os.environ:
return os.environ[KAGConstants.K8S_ENV_WORLD_SIZE]
tf_config = parse_tf_config()
if tf_config is None:
return default
num_master = get_role_number(tf_config, "master")
num_worker = get_role_number(tf_config, "worker")
return num_master + num_worker
def get_cfg_files():
"""
Get global and local kag config files and paths.
"""
local_cfg_path = _closest_cfg()
local_cfg = ConfigParser()
local_cfg.read(local_cfg_path)
if local_cfg_path:
projdir = str(Path(local_cfg_path).parent)
if projdir not in sys.path:
sys.path.append(projdir)
return local_cfg, local_cfg_path
def get_master_port(default=None):
return os.environ.get(KAGConstants.K8S_ENV_MASTER_PORT, default)
def get_master_addr(default=None):
if KAGConstants.K8S_ENV_MASTER_ADDR in os.environ:
return os.environ[KAGConstants.K8S_ENV_MASTER_ADDR]
def init_kag_config(config_path: Union[str, Path] = None):
if not config_path or (isinstance(config_path, Path) and not config_path.exists()):
config_path = DEFAULT_KAG_CONFIG_PATH
kag_cfg = ConfigParser()
kag_cfg.read(config_path)
os.environ["KAG_PROJECT_ROOT_PATH"] = os.path.abspath(os.path.dirname(config_path))
tf_config = parse_tf_config()
if tf_config is None:
return default
for section in kag_cfg.sections():
sec_cfg = {}
for key, value in kag_cfg.items(section):
item_cfg_key = f"{KAG_CFG_PREFIX}_{section}_{key}".upper()
os.environ[item_cfg_key] = value
sec_cfg[key] = value
sec_cfg_key = f"{KAG_CFG_PREFIX}_{section}".upper()
os.environ[sec_cfg_key] = str(sec_cfg)
if section == "log":
for key, value in kag_cfg.items(section):
if key == "level":
logging.basicConfig(level=logging.getLevelName(value))
# neo4j log level set to be default error
logging.getLogger("neo4j.notifications").setLevel(logging.ERROR)
logging.getLogger("neo4j.io").setLevel(logging.INFO)
logging.getLogger("neo4j.pool").setLevel(logging.INFO)
return tf_config["cluster"]["worker"][0]
def host2tensor(master_port):
import torch
host_str = socket.gethostbyname(socket.gethostname())
host = [int(x) for x in host_str.split(".")]
host.append(int(master_port))
host_tensor = torch.tensor(host)
return host_tensor
def tensor2host(host_tensor):
host_tensor = host_tensor.tolist()
host = ".".join([str(x) for x in host_tensor[0:4]])
port = host_tensor[4]
return f"{host}:{port}"
def sync_hosts():
import torch
import torch.distributed as dist
rank = get_rank()
if rank is None:
raise ValueError("can't get rank of container")
rank = int(rank)
world_size = get_world_size()
if world_size is None:
raise ValueError("can't get world_size of container")
world_size = int(world_size)
master_port = get_master_port()
if master_port is None:
raise ValueError("can't get master_port of container")
master_port = int(master_port)
while True:
try:
dist.init_process_group(
backend="gloo",
rank=rank,
world_size=world_size,
timeout=datetime.timedelta(days=1),
)
break
except Exception as e:
error_traceback = traceback.format_exc()
print(f"failed to init process group, info: {e}\n\n\n{error_traceback}")
time.sleep(60)
print("Done init process group, get all hosts...")
host_tensors = [torch.tensor([0, 0, 0, 0, 0]) for x in range(world_size)]
dist.all_gather(host_tensors, host2tensor(master_port))
# we need to destroy the torch process group to release MASTER_PORT, otherwise the server
# can't serve on it.
print("Done gathering all hosts, destroying process group...")
dist.destroy_process_group()
time.sleep(10)
return [tensor2host(x) for x in host_tensors]
def extract_job_name_from_pod_name(pod_name):
if "-ptjob" in pod_name:
return pod_name.rsplit("-ptjob", maxsplit=1)[0]
elif "-tfjob" in pod_name:
return pod_name.rsplit("-tfjob", maxsplit=1)[0]
elif "-mpijob" in pod_name:
return pod_name.rsplit("-mpijob", maxsplit=1)[0]
else:
return None
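The rank/world-size helpers above fall back to a Kubernetes TF_CONFIG-style JSON payload when the dedicated RANK/WORLD_SIZE variables are absent. A self-contained sketch of that derivation; the env var name TF_CONFIG and the cluster layout are illustrative assumptions.

# Sketch: derive rank and world size from a TF_CONFIG-style payload, as the helpers above do.
import json
import os

os.environ["TF_CONFIG"] = json.dumps(  # assumed name behind KS8_ENV_TF_CONFIG
    {
        "cluster": {"master": ["m0:2222"], "worker": ["w0:2222", "w1:2222"]},
        "task": {"type": "worker", "index": 1},
    }
)

tf_config = json.loads(os.environ["TF_CONFIG"])
num_master = len(tf_config["cluster"].get("master", []))
num_worker = len(tf_config["cluster"].get("worker", []))
task = tf_config["task"]
rank = task["index"] if task["type"] == "master" else num_master + task["index"]
world_size = num_master + num_worker
print(rank, world_size)  # -> 2 3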

View File

@ -49,7 +49,9 @@ class GraphStore(ABC):
pass
@abstractmethod
def upsert_nodes(self, label, properties_list, id_key="id", extra_labels=("Entity",)):
def upsert_nodes(
self, label, properties_list, id_key="id", extra_labels=("Entity",)
):
"""
Insert or update multiple nodes.
@ -112,10 +114,18 @@ class GraphStore(ABC):
pass
@abstractmethod
def upsert_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value,
rel_type, properties, upsert_nodes=True,
start_node_id_key="id", end_node_id_key="id"):
def upsert_relationship(
self,
start_node_label,
start_node_id_value,
end_node_label,
end_node_id_value,
rel_type,
properties,
upsert_nodes=True,
start_node_id_key="id",
end_node_id_key="id",
):
"""
Insert or update a relationship.
@ -133,9 +143,16 @@ class GraphStore(ABC):
pass
@abstractmethod
def upsert_relationships(self, start_node_label, end_node_label, rel_type,
relationships, upsert_nodes=True, start_node_id_key="id",
end_node_id_key="id"):
def upsert_relationships(
self,
start_node_label,
end_node_label,
rel_type,
relationships,
upsert_nodes=True,
start_node_id_key="id",
end_node_id_key="id",
):
"""
Insert or update multiple relationships.
@ -151,9 +168,16 @@ class GraphStore(ABC):
pass
@abstractmethod
def delete_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value,
rel_type, start_node_id_key="id", end_node_id_key="id"):
def delete_relationship(
self,
start_node_label,
start_node_id_value,
end_node_label,
end_node_id_value,
rel_type,
start_node_id_key="id",
end_node_id_key="id",
):
"""
Delete a specified relationship.
@ -169,9 +193,16 @@ class GraphStore(ABC):
pass
@abstractmethod
def delete_relationships(self, start_node_label, start_node_id_values,
end_node_label, end_node_id_values, rel_type,
start_node_id_key="id", end_node_id_key="id"):
def delete_relationships(
self,
start_node_label,
start_node_id_values,
end_node_label,
end_node_id_values,
rel_type,
start_node_id_key="id",
end_node_id_key="id",
):
"""
Delete multiple relationships.
@ -211,9 +242,16 @@ class GraphStore(ABC):
pass
@abstractmethod
def create_vector_index(self, label, property_key, index_name=None,
vector_dimensions=768, metric_type="cosine",
hnsw_m=None, hnsw_ef_construction=None):
def create_vector_index(
self,
label,
property_key,
index_name=None,
vector_dimensions=768,
metric_type="cosine",
hnsw_m=None,
hnsw_ef_construction=None,
):
"""
Create a vector index.
@ -239,7 +277,9 @@ class GraphStore(ABC):
pass
@abstractmethod
def text_search(self, query_string, label_constraints=None, topk=10, index_name=None):
def text_search(
self, query_string, label_constraints=None, topk=10, index_name=None
):
"""
Perform a text search.
@ -255,7 +295,15 @@ class GraphStore(ABC):
pass
@abstractmethod
def vector_search(self, label, property_key, query_text_or_vector, topk=10, index_name=None, ef_search=None):
def vector_search(
self,
label,
property_key,
query_text_or_vector,
topk=10,
index_name=None,
ef_search=None,
):
"""
Perform a vector search.
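The hunks above only reflow the abstract GraphStore signatures onto multiple lines. For readability, here is a minimal in-memory stand-in that mirrors two of those signatures; it is an illustrative stub under assumed semantics, not the Neo4j implementation that follows.

# Illustrative in-memory stand-in mirroring two of the reformatted signatures above.
class InMemoryGraphStore:
    def __init__(self):
        self.nodes = {}  # (label, id) -> properties
        self.rels = []   # (start_key, rel_type, end_key, properties)

    def upsert_nodes(
        self, label, properties_list, id_key="id", extra_labels=("Entity",)
    ):
        for properties in properties_list:
            self.nodes[(label, properties[id_key])] = dict(properties)
        return [self.nodes[(label, p[id_key])] for p in properties_list]

    def upsert_relationship(
        self,
        start_node_label,
        start_node_id_value,
        end_node_label,
        end_node_id_value,
        rel_type,
        properties,
        upsert_nodes=True,
        start_node_id_key="id",
        end_node_id_key="id",
    ):
        start_key = (start_node_label, start_node_id_value)
        end_key = (end_node_label, end_node_id_value)
        if upsert_nodes:
            self.nodes.setdefault(start_key, {start_node_id_key: start_node_id_value})
            self.nodes.setdefault(end_key, {end_node_id_key: end_node_id_value})
        self.rels.append((start_key, rel_type, end_key, dict(properties)))
        return self.rels[-1]


store = InMemoryGraphStore()
store.upsert_nodes("Person", [{"id": "p1", "name": "Alice"}])
store.upsert_relationship("Person", "p1", "Company", "c1", "worksAt", {"since": 2020})
print(len(store.nodes), len(store.rels))  # -> 2 1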

View File

@ -10,7 +10,6 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import re
import threading
import time
@ -25,18 +24,20 @@ from knext.schema.model.base import IndexTypeEnum
logger = logging.getLogger(__name__)
class SingletonMeta(ABCMeta):
"""
Thread-safe Singleton metaclass
"""
_instances = {}
_lock = threading.Lock()
def __call__(cls, *args, **kwargs):
uri = kwargs.get('uri')
user = kwargs.get('user')
password = kwargs.get('password')
database = kwargs.get('database', 'neo4j')
uri = kwargs.get("uri")
user = kwargs.get("user")
password = kwargs.get("password")
database = kwargs.get("database", "neo4j")
key = (cls, uri, user, password, database)
with cls._lock:
@ -46,12 +47,19 @@ class SingletonMeta(ABCMeta):
class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def __init__(self, uri, user, password, database="neo4j", init_type="write", interval_minutes=10):
def __init__(
self,
uri,
user,
password,
database="neo4j",
init_type="write",
interval_minutes=10,
):
self._driver = GraphDatabase.driver(uri, auth=(user, password))
logger.info(f"init Neo4jClient uri: {uri} database: {database}")
self._database = database
self._lucene_special_chars = "\\+-!():^[]\"{}~*?|&/"
self._lucene_special_chars = '\\+-!():^[]"{}~*?|&/'
self._lucene_pattern = self._get_lucene_pattern()
self._simple_ident = "[A-Za-z_][A-Za-z0-9_]*"
self._simple_ident_pattern = re.compile(self._simple_ident)
@ -71,14 +79,16 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
self._driver.close()
def schedule_constraint(self, interval_minutes):
def job():
try:
self._labels = self._create_unique_constraint()
self._update_pagerank_graph()
except Exception as e:
import traceback
logger.error(f"Error run scheduled job: {traceback.format_exc()}")
logger.error(
f"Error running scheduled job, info: {e},\ntraceback:\n{traceback.format_exc()}"
)
def run_scheduled_tasks():
while True:
@ -116,7 +126,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
try:
result = session.run(create_constraint_query)
result.consume()
logger.debug(f"Unique constraint created for constraint_name: {constraint_name}")
logger.debug(
f"Unique constraint created for constraint_name: {constraint_name}"
)
except Exception as e:
logger.debug(f"warn creating constraint for {constraint_name}: {e}")
self._create_index_constraint(self, label, session)
@ -186,7 +198,12 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
label_property_keys = {}
for property_key in properties:
index_type = properties[property_key].index_type
if property_key == "name" or index_type and index_type in (IndexTypeEnum.Text, IndexTypeEnum.TextAndVector):
if (
property_key == "name"
or index_type
and index_type
in (IndexTypeEnum.Text, IndexTypeEnum.TextAndVector)
):
label_property_keys[property_key] = True
if label_property_keys:
labels[label] = True
@ -199,9 +216,13 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
if label not in self._labels:
self._create_unique_index_constraint(self, label, session)
try:
return session.execute_write(self._upsert_node, self, label, id_key, properties, extra_labels)
return session.execute_write(
self._upsert_node, self, label, id_key, properties, extra_labels
)
except Exception as e:
logger.error(f"upsert_node label:{label} properties:{properties} Exception: {e}")
logger.error(
f"upsert_node label:{label} properties:{properties} Exception: {e}"
)
return None
@staticmethod
@ -209,23 +230,36 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
if not label:
logger.warning("label cannot be None or empty strings")
return None
query = (f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $properties.{self._escape_neo4j(id_key)}}}) "
"SET n += $properties ")
query = (
f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $properties.{self._escape_neo4j(id_key)}}}) "
"SET n += $properties "
)
if extra_labels:
query += f", n:{':'.join(self._escape_neo4j(extra_label) for extra_label in extra_labels)} "
query += "RETURN n"
result = tx.run(query, properties=properties)
return result.single()[0]
def upsert_nodes(self, label, properties_list, id_key="id", extra_labels=("Entity",)):
def upsert_nodes(
self, label, properties_list, id_key="id", extra_labels=("Entity",)
):
self._preprocess_node_properties_list(label, properties_list, extra_labels)
with self._driver.session(database=self._database) as session:
if label not in self._labels:
self._create_unique_index_constraint(self, label, session)
try:
return session.execute_write(self._upsert_nodes, self, label, properties_list, id_key, extra_labels)
return session.execute_write(
self._upsert_nodes,
self,
label,
properties_list,
id_key,
extra_labels,
)
except Exception as e:
logger.error(f"upsert_nodes label:{label} properties:{properties_list} Exception: {e}")
logger.error(
f"upsert_nodes label:{label} properties:{properties_list} Exception: {e}"
)
return None
@staticmethod
@ -233,14 +267,16 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
if not label:
logger.warning("label cannot be None or empty strings")
return None
query = ("UNWIND $properties_list AS properties "
f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: properties.{self._escape_neo4j(id_key)}}}) "
"SET n += properties ")
query = (
"UNWIND $properties_list AS properties "
f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: properties.{self._escape_neo4j(id_key)}}}) "
"SET n += properties "
)
if extra_labels:
query += f", n:{':'.join(self._escape_neo4j(extra_label) for extra_label in extra_labels)} "
query += "RETURN n"
result = tx.run(query, properties_list=properties_list)
return [record['n'] for record in result]
return [record["n"] for record in result]
def _get_embedding_vector(self, properties, vector_field):
for property_key, property_value in properties.items():
@ -256,7 +292,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
vector = self.vectorizer.vectorize(property_value)
return vector
except Exception as e:
logger.info(f"An error occurred while vectorizing property {property_key!r}: {e}")
logger.info(
f"An error occurred while vectorizing property {property_key!r}: {e}"
)
return None
return None
@ -287,7 +325,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
return
class EmbeddingVectorPlaceholder(object):
def __init__(self, number, properties, vector_field, property_key, property_value):
def __init__(
self, number, properties, vector_field, property_key, property_value
):
self._number = number
self._properties = properties
self._vector_field = vector_field
@ -317,7 +357,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
message = f"property {property_key!r} must be string to generate embedding vector"
raise RuntimeError(message)
num = len(self._placeholders)
placeholder = EmbeddingVectorPlaceholder(num, properties, vector_field, property_key, property_value)
placeholder = EmbeddingVectorPlaceholder(
num, properties, vector_field, property_key, property_value
)
self._placeholders.append(placeholder)
return placeholder
return None
@ -364,7 +406,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
for vector_field in vec_meta[label]:
if vector_field in properties:
continue
placeholder = manager.get_placeholder(self, properties, vector_field)
placeholder = manager.get_placeholder(
self, properties, vector_field
)
if placeholder is not None:
properties[vector_field] = placeholder
manager.batch_vectorize(self._vectorizer)
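The placeholder machinery above defers embedding computation: string properties are temporarily replaced by placeholder objects, all pending texts are vectorized in one batched call, and the placeholders are then patched back with the resulting vectors. A condensed sketch of that pattern with a dummy vectorizer (the real class delegates to self._vectorizer and has more bookkeeping):

# Condensed sketch of the deferred batch-vectorization pattern used above (dummy vectorizer).
class Placeholder:
    def __init__(self, properties, field, text):
        self.properties, self.field, self.text = properties, field, text
        self.vector = None

    def patch(self):
        self.properties[self.field] = self.vector


class PlaceholderManager:
    def __init__(self):
        self._placeholders = []

    def get_placeholder(self, properties, field, text):
        ph = Placeholder(properties, field, text)
        self._placeholders.append(ph)
        return ph

    def batch_vectorize(self, vectorizer):
        vectors = vectorizer([ph.text for ph in self._placeholders])  # one batched call
        for ph, vec in zip(self._placeholders, vectors):
            ph.vector = vec

    def patch(self):
        for ph in self._placeholders:
            ph.patch()


def dummy_vectorizer(texts):
    return [[float(len(t)), 0.0] for t in texts]  # stand-in for the embedding service


props = {"name": "Alice", "desc": "works at ACME"}
manager = PlaceholderManager()
props["_name_vector"] = manager.get_placeholder(props, "_name_vector", props["name"])
props["_desc_vector"] = manager.get_placeholder(props, "_desc_vector", props["desc"])
manager.batch_vectorize(dummy_vectorizer)
manager.patch()
print(props["_name_vector"])  # -> [5.0, 0.0]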
@ -406,25 +450,58 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
query = f"UNWIND $id_values AS id_value MATCH (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: id_value}}) DETACH DELETE n"
tx.run(query, id_values=id_values)
def upsert_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value, rel_type,
properties, upsert_nodes=True, start_node_id_key="id", end_node_id_key="id"):
def upsert_relationship(
self,
start_node_label,
start_node_id_value,
end_node_label,
end_node_id_value,
rel_type,
properties,
upsert_nodes=True,
start_node_id_key="id",
end_node_id_key="id",
):
rel_type = self._escape_neo4j(rel_type)
with self._driver.session(database=self._database) as session:
try:
return session.execute_write(self._upsert_relationship, self, start_node_label, start_node_id_key,
start_node_id_value, end_node_label, end_node_id_key,
end_node_id_value, rel_type, properties, upsert_nodes)
return session.execute_write(
self._upsert_relationship,
self,
start_node_label,
start_node_id_key,
start_node_id_value,
end_node_label,
end_node_id_key,
end_node_id_value,
rel_type,
properties,
upsert_nodes,
)
except Exception as e:
logger.error(f"upsert_relationship rel_type:{rel_type} properties:{properties} Exception: {e}")
logger.error(
f"upsert_relationship rel_type:{rel_type} properties:{properties} Exception: {e}"
)
return None
@staticmethod
def _upsert_relationship(tx, self, start_node_label, start_node_id_key, start_node_id_value,
end_node_label, end_node_id_key, end_node_id_value,
rel_type, properties, upsert_nodes):
def _upsert_relationship(
tx,
self,
start_node_label,
start_node_id_key,
start_node_id_value,
end_node_label,
end_node_id_key,
end_node_id_value,
rel_type,
properties,
upsert_nodes,
):
if not start_node_label or not end_node_label or not rel_type:
logger.warning("start_node_label, end_node_label, and rel_type cannot be None or empty strings")
logger.warning(
"start_node_label, end_node_label, and rel_type cannot be None or empty strings"
)
return None
if upsert_nodes:
query = (
@ -438,25 +515,59 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) "
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += $properties RETURN r"
)
result = tx.run(query, start_node_id_value=start_node_id_value,
end_node_id_value=end_node_id_value, properties=properties)
result = tx.run(
query,
start_node_id_value=start_node_id_value,
end_node_id_value=end_node_id_value,
properties=properties,
)
return result.single()
def upsert_relationships(self, start_node_label, end_node_label, rel_type, relations,
upsert_nodes=True, start_node_id_key="id", end_node_id_key="id"):
def upsert_relationships(
self,
start_node_label,
end_node_label,
rel_type,
relations,
upsert_nodes=True,
start_node_id_key="id",
end_node_id_key="id",
):
with self._driver.session(database=self._database) as session:
try:
return session.execute_write(self._upsert_relationships, self, relations, start_node_label,
start_node_id_key, end_node_label, end_node_id_key, rel_type, upsert_nodes)
return session.execute_write(
self._upsert_relationships,
self,
relations,
start_node_label,
start_node_id_key,
end_node_label,
end_node_id_key,
rel_type,
upsert_nodes,
)
except Exception as e:
logger.error(f"upsert_relationships rel_type:{rel_type} relations:{relations} Exception: {e}")
logger.error(
f"upsert_relationships rel_type:{rel_type} relations:{relations} Exception: {e}"
)
return None
@staticmethod
def _upsert_relationships(tx, self, relations, start_node_label, start_node_id_key,
end_node_label, end_node_id_key, rel_type, upsert_nodes):
def _upsert_relationships(
tx,
self,
relations,
start_node_label,
start_node_id_key,
end_node_label,
end_node_id_key,
rel_type,
upsert_nodes,
):
if not start_node_label or not end_node_label or not rel_type:
logger.warning("start_node_label, end_node_label, and rel_type cannot be None or empty strings")
logger.warning(
"start_node_label, end_node_label, and rel_type cannot be None or empty strings"
)
return None
if upsert_nodes:
query = (
@ -473,51 +584,111 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += relationship.properties RETURN r"
)
result = tx.run(query, relations=relations,
start_node_label=start_node_label, start_node_id_key=start_node_id_key,
end_node_label=end_node_label, end_node_id_key=end_node_id_key,
rel_type=rel_type)
return [record['r'] for record in result]
result = tx.run(
query,
relations=relations,
start_node_label=start_node_label,
start_node_id_key=start_node_id_key,
end_node_label=end_node_label,
end_node_id_key=end_node_id_key,
rel_type=rel_type,
)
return [record["r"] for record in result]
def delete_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value, rel_type,
start_node_id_key="id", end_node_id_key="id"):
def delete_relationship(
self,
start_node_label,
start_node_id_value,
end_node_label,
end_node_id_value,
rel_type,
start_node_id_key="id",
end_node_id_key="id",
):
with self._driver.session(database=self._database) as session:
try:
session.execute_write(self._delete_relationship, self, start_node_label, start_node_id_key,
start_node_id_value, end_node_label, end_node_id_key,
end_node_id_value, rel_type)
session.execute_write(
self._delete_relationship,
self,
start_node_label,
start_node_id_key,
start_node_id_value,
end_node_label,
end_node_id_key,
end_node_id_value,
rel_type,
)
except Exception as e:
logger.error(f"delete_relationship rel_type:{rel_type} Exception: {e}")
@staticmethod
def _delete_relationship(tx, self, start_node_label, start_node_id_key, start_node_id_value,
end_node_label, end_node_id_key, end_node_id_value, rel_type):
def _delete_relationship(
tx,
self,
start_node_label,
start_node_id_key,
start_node_id_value,
end_node_label,
end_node_id_key,
end_node_id_value,
rel_type,
):
query = (
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: $start_node_id_value}})-[r:{self._escape_neo4j(rel_type)}]->"
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) DELETE r"
)
tx.run(query, start_node_id_value=start_node_id_value, end_node_id_value=end_node_id_value)
tx.run(
query,
start_node_id_value=start_node_id_value,
end_node_id_value=end_node_id_value,
)
def delete_relationships(self, start_node_label, start_node_id_values,
end_node_label, end_node_id_values, rel_type,
start_node_id_key="id", end_node_id_key="id"):
def delete_relationships(
self,
start_node_label,
start_node_id_values,
end_node_label,
end_node_id_values,
rel_type,
start_node_id_key="id",
end_node_id_key="id",
):
with self._driver.session(database=self._database) as session:
session.execute_write(self._delete_relationships, self,
start_node_label, start_node_id_key, start_node_id_values,
end_node_label, end_node_id_key, end_node_id_values, rel_type)
session.execute_write(
self._delete_relationships,
self,
start_node_label,
start_node_id_key,
start_node_id_values,
end_node_label,
end_node_id_key,
end_node_id_values,
rel_type,
)
@staticmethod
def _delete_relationships(tx, self, start_node_label, start_node_id_key, start_node_id_values,
end_node_label, end_node_id_key, end_node_id_values, rel_type):
def _delete_relationships(
tx,
self,
start_node_label,
start_node_id_key,
start_node_id_values,
end_node_label,
end_node_id_key,
end_node_id_values,
rel_type,
):
query = (
"UNWIND $start_node_id_values AS start_node_id_value "
"UNWIND $end_node_id_values AS end_node_id_value "
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: start_node_id_value}})-[r:{self._escape_neo4j(rel_type)}]->"
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: end_node_id_value}}) DELETE r"
)
tx.run(query, start_node_id_values=start_node_id_values, end_node_id_values=end_node_id_values)
tx.run(
query,
start_node_id_values=start_node_id_values,
end_node_id_values=end_node_id_values,
)
def _get_lucene_pattern(self):
string = re.escape(self._lucene_special_chars)
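The Lucene query string handed to db.index.fulltext.queryNodes has to have Lucene's query-syntax characters escaped first. The exact pattern built by _get_lucene_pattern is not shown in this hunk, so the following standalone sketch is an assumption about its shape rather than the actual implementation:

# Approximate sketch of Lucene query escaping; the real pattern may differ.
import re

LUCENE_SPECIAL_CHARS = '\\+-!():^[]"{}~*?|&/'  # same character set as above
lucene_pattern = re.compile("([" + re.escape(LUCENE_SPECIAL_CHARS) + "])")


def make_lucene_query(query_string):
    # Prefix every special character with a backslash so Lucene treats it literally.
    return lucene_pattern.sub(r"\\\1", query_string)


print(make_lucene_query('who is "Alice" (CEO)?'))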
@ -539,7 +710,7 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
for ch in string:
data = ch.encode("utf-16-le")
for i in range(0, len(data), 2):
value = int.from_bytes(data[i:i+2], "little")
value = int.from_bytes(data[i : i + 2], "little")
result.append(value)
return tuple(result)
@ -562,6 +733,7 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def _to_snake_case(self, name):
import re
words = re.findall("[A-Za-z][a-z0-9]*", name)
result = "_".join(words).lower()
return result
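_to_snake_case splits a name on capitalized word boundaries and joins the pieces with underscores, presumably for deriving vector field and index names. A quick worked example of the same regex:

# Worked example of the snake-case conversion used above.
import re


def to_snake_case(name):
    words = re.findall("[A-Za-z][a-z0-9]*", name)
    return "_".join(words).lower()


print(to_snake_case("SemanticConcept"))  # -> semantic_concept
print(to_snake_case("name"))             # -> name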
@ -578,7 +750,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def create_index(self, label, property_key, index_name=None):
with self._driver.session(database=self._database) as session:
session.execute_write(self._create_index, self, label, property_key, index_name)
session.execute_write(
self._create_index, self, label, property_key, index_name
)
@staticmethod
def _create_index(tx, self, label, property_key, index_name):
@ -596,50 +770,87 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
if index_name is None:
index_name = "_default_text_index"
label_spec = "|".join(self._escape_neo4j(label) for label in labels)
property_spec = ", ".join(f"n.{self._escape_neo4j(key)}" for key in property_keys)
property_spec = ", ".join(
f"n.{self._escape_neo4j(key)}" for key in property_keys
)
query = (
f"CREATE FULLTEXT INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS "
f"FOR (n:{label_spec}) ON EACH [{property_spec}]"
)
def do_create_text_index(tx):
tx.run(query)
with self._driver.session(database=self._database) as session:
session.execute_write(do_create_text_index)
return index_name
def create_vector_index(self, label, property_key, index_name=None,
vector_dimensions=768, metric_type="cosine",
hnsw_m=None, hnsw_ef_construction=None):
def create_vector_index(
self,
label,
property_key,
index_name=None,
vector_dimensions=768,
metric_type="cosine",
hnsw_m=None,
hnsw_ef_construction=None,
):
if index_name is None:
index_name = self._create_vector_index_name(label, property_key)
if not property_key.lower().endswith("vector"):
property_key = self._create_vector_field_name(property_key)
with self._driver.session(database=self._database) as session:
session.execute_write(self._create_vector_index, self, label, property_key, index_name,
vector_dimensions, metric_type, hnsw_m, hnsw_ef_construction)
session.execute_write(
self._create_vector_index,
self,
label,
property_key,
index_name,
vector_dimensions,
metric_type,
hnsw_m,
hnsw_ef_construction,
)
self.refresh_vector_index_meta(force=True)
return index_name
@staticmethod
def _create_vector_index(tx, self, label, property_key, index_name, vector_dimensions, metric_type, hnsw_m, hnsw_ef_construction):
def _create_vector_index(
tx,
self,
label,
property_key,
index_name,
vector_dimensions,
metric_type,
hnsw_m,
hnsw_ef_construction,
):
query = (
f"CREATE VECTOR INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) ON (n.{self._escape_neo4j(property_key)}) "
"OPTIONS { indexConfig: {"
" `vector.dimensions`: $vector_dimensions,"
" `vector.similarity_function`: $metric_type"
"OPTIONS { indexConfig: {"
" `vector.dimensions`: $vector_dimensions,"
" `vector.similarity_function`: $metric_type"
)
if hnsw_m is not None:
query += ", `vector.hnsw.m`: $hnsw_m"
if hnsw_ef_construction is not None:
query += ", `vector.hnsw.ef_construction`: $hnsw_ef_construction"
query += "}}"
tx.run(query, vector_dimensions=vector_dimensions, metric_type=metric_type,
hnsw_m=hnsw_m, hnsw_ef_construction=hnsw_ef_construction)
tx.run(
query,
vector_dimensions=vector_dimensions,
metric_type=metric_type,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
)
def refresh_vector_index_meta(self, force=False):
import time
if not force and time.time() - self._vec_meta_ts < self._vec_meta_timeout:
return
def do_refresh_vector_index_meta(tx):
query = "SHOW VECTOR INDEX"
res = tx.run(query)
@ -647,14 +858,17 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
meta = dict()
for record in data:
if record["entityType"] == "NODE":
label, = record["labelsOrTypes"]
vector_field, = record["properties"]
if vector_field.startswith("_") and vector_field.endswith("_vector"):
(label,) = record["labelsOrTypes"]
(vector_field,) = record["properties"]
if vector_field.startswith("_") and vector_field.endswith(
"_vector"
):
if label not in meta:
meta[label] = []
meta[label].append(vector_field)
self._vec_meta = meta
self._vec_meta_ts = time.time()
with self._driver.session(database=self._database) as session:
session.execute_read(do_refresh_vector_index_meta)
@ -678,7 +892,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def vectorizer(self, value):
self._vectorizer = value
def text_search(self, query_string, label_constraints=None, topk=10, index_name=None):
def text_search(
self, query_string, label_constraints=None, topk=10, index_name=None
):
if index_name is None:
index_name = "_default_text_index"
if label_constraints is None:
@ -686,31 +902,48 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
elif isinstance(label_constraints, str):
label_constraints = self._escape_neo4j(label_constraints)
elif isinstance(label_constraints, (list, tuple)):
label_constraints = "|".join(self._escape_neo4j(label_constraint) for label_constraint in label_constraints)
label_constraints = "|".join(
self._escape_neo4j(label_constraint)
for label_constraint in label_constraints
)
else:
message = f"invalid label_constraints: {label_constraints!r}"
raise RuntimeError(message)
if label_constraints is None:
query = ("CALL db.index.fulltext.queryNodes($index_name, $query_string) "
"YIELD node AS node, score "
"RETURN node, score")
query = (
"CALL db.index.fulltext.queryNodes($index_name, $query_string) "
"YIELD node AS node, score "
"RETURN node, score"
)
else:
query = ("CALL db.index.fulltext.queryNodes($index_name, $query_string) "
"YIELD node AS node, score "
f"WHERE (node:{label_constraints}) "
"RETURN node, score")
query = (
"CALL db.index.fulltext.queryNodes($index_name, $query_string) "
"YIELD node AS node, score "
f"WHERE (node:{label_constraints}) "
"RETURN node, score"
)
query += " LIMIT $topk"
query_string = self._make_lucene_query(query_string)
def do_text_search(tx):
res = tx.run(query, query_string=query_string, topk=topk, index_name=index_name)
res = tx.run(
query, query_string=query_string, topk=topk, index_name=index_name
)
data = res.data()
return data
with self._driver.session(database=self._database) as session:
return session.execute_read(do_text_search)
def vector_search(self, label, property_key, query_text_or_vector, topk=10, index_name=None, ef_search=None):
def vector_search(
self,
label,
property_key,
query_text_or_vector,
topk=10,
index_name=None,
ef_search=None,
):
if ef_search is not None:
if ef_search < topk:
message = f"ef_search must be greater than or equal to topk; {ef_search!r} is invalid"
@ -719,13 +952,17 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
if index_name is None:
vec_meta = self._vec_meta
if label not in vec_meta:
logger.warning(f"vector index not defined for label, return empty. label: {label}, "
f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}.")
logger.warning(
f"vector index not defined for label, return empty. label: {label}, "
f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}."
)
return []
vector_field = self._create_vector_field_name(property_key)
if vector_field not in vec_meta[label]:
logger.warning(f"vector index not defined for field, return empty. label: {label}, "
f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}.")
logger.warning(
f"vector index not defined for field, return empty. label: {label}, "
f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}."
)
return []
if index_name is None:
index_name = self._create_vector_index_name(label, property_key)
@ -736,16 +973,27 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def do_vector_search(tx):
if ef_search is not None:
query = ("CALL db.index.vector.queryNodes($index_name, $ef_search, $query_vector) "
"YIELD node, score "
"RETURN node, score, labels(node) as __labels__"
f"LIMIT {topk}")
res = tx.run(query, query_vector=query_vector, ef_search=ef_search, index_name=index_name)
query = (
"CALL db.index.vector.queryNodes($index_name, $ef_search, $query_vector) "
"YIELD node, score "
"RETURN node, score, labels(node) as __labels__"
f"LIMIT {topk}"
)
res = tx.run(
query,
query_vector=query_vector,
ef_search=ef_search,
index_name=index_name,
)
else:
query = ("CALL db.index.vector.queryNodes($index_name, $topk, $query_vector) "
"YIELD node, score "
"RETURN node, score, labels(node) as __labels__")
res = tx.run(query, query_vector=query_vector, topk=topk, index_name=index_name)
query = (
"CALL db.index.vector.queryNodes($index_name, $topk, $query_vector) "
"YIELD node, score "
"RETURN node, score, labels(node) as __labels__"
)
res = tx.run(
query, query_vector=query_vector, topk=topk, index_name=index_name
)
data = res.data()
for record in data:
record["node"]["__labels__"] = record["__labels__"]
@ -757,41 +1005,59 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def _create_all_graph(self, graph_name):
with self._driver.session(database=self._database) as session:
logger.debug(f"create pagerank graph graph_name{graph_name} database{self._database}")
result = session.run(f"""
logger.debug(
f"create pagerank graph graph_name{graph_name} database{self._database}"
)
result = session.run(
f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE exists
CALL gds.graph.drop('{graph_name}') YIELD graphName
RETURN graphName
""")
"""
)
summary = result.consume()
logger.debug(f"create pagerank graph exists graph_name{graph_name} database{self._database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
logger.debug(
f"create pagerank graph exists graph_name{graph_name} database{self._database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}"
)
result = session.run(f"""
result = session.run(
f"""
CALL gds.graph.project('{graph_name}','*','*')
YIELD graphName, nodeCount AS nodes, relationshipCount AS rels
RETURN graphName, nodes, rels
""")
"""
)
summary = result.consume()
logger.debug(f"create pagerank graph graph_name{graph_name} database{self._database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
logger.debug(
f"create pagerank graph graph_name{graph_name} database{self._database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}"
)
def _drop_all_graph(self, graph_name):
with self._driver.session(database=self._database) as session:
logger.debug(f"drop pagerank graph graph_name{graph_name} database{self._database}")
result = session.run(f"""
logger.debug(
f"drop pagerank graph graph_name{graph_name} database{self._database}"
)
result = session.run(
f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE exists
CALL gds.graph.drop('{graph_name}') YIELD graphName
RETURN graphName
""")
"""
)
result.consume()
logger.debug(f"drop pagerank graph graph_name{graph_name} database{self._database} succeed")
logger.debug(
f"drop pagerank graph graph_name{graph_name} database{self._database} succeed"
)
def execute_pagerank(self, iterations=20, damping_factor=0.85):
with self._driver.session(database=self._database) as session:
return session.execute_write(self._execute_pagerank, iterations, damping_factor)
return session.execute_write(
self._execute_pagerank, iterations, damping_factor
)
@staticmethod
def _execute_pagerank(tx, iterations, damping_factor):
@ -809,7 +1075,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
with self._driver.session(database=self._database) as session:
all_graph = self._allGraph
self._exists_all_graph(session, all_graph)
data = session.execute_write(self._get_pagerank_scores, self, all_graph, start_nodes, target_type)
data = session.execute_write(
self._get_pagerank_scores, self, all_graph, start_nodes, target_type
)
return data
@staticmethod
@ -817,13 +1085,15 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
match_clauses = []
match_identify = []
for index, node in enumerate(start_nodes):
node_type, node_name = node['type'], node['name']
node_type, node_name = node["type"], node["name"]
node_identify = f"node_{index}"
match_clauses.append(f"MATCH ({node_identify}:{self._escape_neo4j(node_type)} {{name: '{escape_single_quotes(node_name)}'}})")
match_clauses.append(
f"MATCH ({node_identify}:{self._escape_neo4j(node_type)} {{name: '{escape_single_quotes(node_name)}'}})"
)
match_identify.append(node_identify)
match_query = ' '.join(match_clauses)
match_identify_str = ', '.join(match_identify)
match_query = " ".join(match_clauses)
match_identify_str = ", ".join(match_identify)
pagerank_query = f"""
{match_query}
@ -845,16 +1115,20 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def _exists_all_graph(session, graph_name):
try:
logger.debug(f"exists pagerank graph graph_name{graph_name}")
result = session.run(f"""
result = session.run(
f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE NOT exists
CALL gds.graph.project('{graph_name}','*','*')
YIELD graphName, nodeCount AS nodes, relationshipCount AS rels
RETURN graphName, nodes, rels
""")
"""
)
summary = result.consume()
logger.debug(f"exists pagerank graph graph_name{graph_name} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
logger.debug(
f"exists pagerank graph graph_name{graph_name} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}"
)
except Exception as e:
logger.debug(f"Error exists pagerank graph {graph_name}: {e}")
@ -873,18 +1147,26 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def create_database(self, database):
with self._driver.session(database=self._database) as session:
database = database.lower()
result = session.run(f"CREATE DATABASE {self._escape_neo4j(database)} IF NOT EXISTS")
result = session.run(
f"CREATE DATABASE {self._escape_neo4j(database)} IF NOT EXISTS"
)
summary = result.consume()
logger.info(f"create_database {database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
logger.info(
f"create_database {database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}"
)
def delete_all_data(self, database):
if self._database != database:
raise ValueError(f"Error: Current database ({self._database}) is not the same as the target database ({database}).")
raise ValueError(
f"Error: Current database ({self._database}) is not the same as the target database ({database})."
)
with self._driver.session(database=database) as session:
while True:
result = session.run("MATCH (n) WITH n LIMIT 100000 DETACH DELETE n RETURN count(*)")
result = session.run(
"MATCH (n) WITH n LIMIT 100000 DETACH DELETE n RETURN count(*)"
)
count = result.single()[0]
logger.info(f"Deleted {count} nodes in this batch.")
if count == 0:
@ -893,7 +1175,9 @@ class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def run_cypher_query(self, database, query, parameters=None):
if database and self._database != database:
raise ValueError(f"Current database ({self._database}) is not the same as the target database ({database}).")
raise ValueError(
f"Current database ({self._database}) is not the same as the target database ({database})."
)
with self._driver.session(database=database) as session:
result = session.run(query, parameters)
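Taken together, the Neo4jClient changes above are formatting-only (black-style reflow plus double-quoted strings). For context, a hedged usage sketch of the client as exposed by these signatures; it assumes a reachable Neo4j instance, placeholder credentials, and that the module path below is correct.

# Hedged usage sketch; requires a running Neo4j and the kag package on the path.
from kag.common.graphstore.neo4j_graph_store import Neo4jClient  # assumed module path

client = Neo4jClient(
    uri="bolt://localhost:7687",  # assumed connection settings
    user="neo4j",
    password="neo4j",
    database="neo4j",
)
client.upsert_nodes("Person", [{"id": "p1", "name": "Alice"}])
client.upsert_relationship("Person", "p1", "Company", "c1", "worksAt", {"since": 2020})
print(client.text_search("Alice", label_constraints="Person", topk=5))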

View File

@ -35,4 +35,6 @@ from kag.common.graphstore.rest.models.delete_vertex_request import DeleteVertex
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance
from kag.common.graphstore.rest.models.vertex_record_instance import (
VertexRecordInstance,
)

View File

@ -18,10 +18,7 @@ import re # noqa: F401
import six
from kag.common.rest.api_client import ApiClient
from kag.common.rest.exceptions import ( # noqa: F401
ApiTypeError,
ApiValueError
)
from kag.common.rest.exceptions import ApiTypeError, ApiValueError # noqa: F401
class GraphApi(object):
@ -57,7 +54,7 @@ class GraphApi(object):
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
kwargs["_return_http_data_only"] = True
return self.graph_delete_edge_post_with_http_info(**kwargs) # noqa: E501
def graph_delete_edge_post_with_http_info(self, **kwargs): # noqa: E501
@ -86,26 +83,24 @@ class GraphApi(object):
local_var_params = locals()
all_params = [
'delete_edge_request'
]
all_params = ["delete_edge_request"]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
"async_req",
"_return_http_data_only",
"_preload_content",
"_request_timeout",
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_delete_edge_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
del local_var_params["kwargs"]
collection_formats = {}
@ -119,34 +114,42 @@ class GraphApi(object):
local_var_files = {}
body_params = None
if 'delete_edge_request' in local_var_params:
body_params = local_var_params['delete_edge_request']
if "delete_edge_request" in local_var_params:
body_params = local_var_params["delete_edge_request"]
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
header_params["Accept"] = self.api_client.select_header_accept(
["application/json"]
) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
header_params[
"Content-Type"
] = self.api_client.select_header_content_type( # noqa: E501
["application/json"]
) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/deleteEdge', 'POST',
"/graph/deleteEdge",
"POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
response_type="object", # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
async_req=local_var_params.get("async_req"),
_return_http_data_only=local_var_params.get(
"_return_http_data_only"
), # noqa: E501
_preload_content=local_var_params.get("_preload_content", True),
_request_timeout=local_var_params.get("_request_timeout"),
collection_formats=collection_formats,
)
def graph_delete_vertex_post(self, **kwargs): # noqa: E501
"""delete_vertex # noqa: E501
@ -169,7 +172,7 @@ class GraphApi(object):
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
kwargs["_return_http_data_only"] = True
return self.graph_delete_vertex_post_with_http_info(**kwargs) # noqa: E501
def graph_delete_vertex_post_with_http_info(self, **kwargs): # noqa: E501
@ -198,26 +201,24 @@ class GraphApi(object):
local_var_params = locals()
all_params = [
'delete_vertex_request'
]
all_params = ["delete_vertex_request"]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
"async_req",
"_return_http_data_only",
"_preload_content",
"_request_timeout",
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_delete_vertex_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
del local_var_params["kwargs"]
collection_formats = {}
@ -231,34 +232,42 @@ class GraphApi(object):
local_var_files = {}
body_params = None
if 'delete_vertex_request' in local_var_params:
body_params = local_var_params['delete_vertex_request']
if "delete_vertex_request" in local_var_params:
body_params = local_var_params["delete_vertex_request"]
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
header_params["Accept"] = self.api_client.select_header_accept(
["application/json"]
) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
header_params[
"Content-Type"
] = self.api_client.select_header_content_type( # noqa: E501
["application/json"]
) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/deleteVertex', 'POST',
"/graph/deleteVertex",
"POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
response_type="object", # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
async_req=local_var_params.get("async_req"),
_return_http_data_only=local_var_params.get(
"_return_http_data_only"
), # noqa: E501
_preload_content=local_var_params.get("_preload_content", True),
_request_timeout=local_var_params.get("_request_timeout"),
collection_formats=collection_formats,
)
def graph_upsert_edge_post(self, **kwargs): # noqa: E501
"""upsert_edge # noqa: E501
@ -281,7 +290,7 @@ class GraphApi(object):
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
kwargs["_return_http_data_only"] = True
return self.graph_upsert_edge_post_with_http_info(**kwargs) # noqa: E501
def graph_upsert_edge_post_with_http_info(self, **kwargs): # noqa: E501
@ -310,26 +319,24 @@ class GraphApi(object):
local_var_params = locals()
all_params = [
'upsert_edge_request'
]
all_params = ["upsert_edge_request"]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
"async_req",
"_return_http_data_only",
"_preload_content",
"_request_timeout",
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_upsert_edge_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
del local_var_params["kwargs"]
collection_formats = {}
@ -343,34 +350,42 @@ class GraphApi(object):
local_var_files = {}
body_params = None
if 'upsert_edge_request' in local_var_params:
body_params = local_var_params['upsert_edge_request']
if "upsert_edge_request" in local_var_params:
body_params = local_var_params["upsert_edge_request"]
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
header_params["Accept"] = self.api_client.select_header_accept(
["application/json"]
) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
header_params[
"Content-Type"
] = self.api_client.select_header_content_type( # noqa: E501
["application/json"]
) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/upsertEdge', 'POST',
"/graph/upsertEdge",
"POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
response_type="object", # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
async_req=local_var_params.get("async_req"),
_return_http_data_only=local_var_params.get(
"_return_http_data_only"
), # noqa: E501
_preload_content=local_var_params.get("_preload_content", True),
_request_timeout=local_var_params.get("_request_timeout"),
collection_formats=collection_formats,
)
def graph_upsert_vertex_post(self, **kwargs): # noqa: E501
"""upsert_vertex # noqa: E501
@ -393,7 +408,7 @@ class GraphApi(object):
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
kwargs["_return_http_data_only"] = True
return self.graph_upsert_vertex_post_with_http_info(**kwargs) # noqa: E501
def graph_upsert_vertex_post_with_http_info(self, **kwargs): # noqa: E501
@ -422,26 +437,24 @@ class GraphApi(object):
local_var_params = locals()
all_params = [
'upsert_vertex_request'
]
all_params = ["upsert_vertex_request"]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
"async_req",
"_return_http_data_only",
"_preload_content",
"_request_timeout",
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_upsert_vertex_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
del local_var_params["kwargs"]
collection_formats = {}
@ -455,31 +468,39 @@ class GraphApi(object):
local_var_files = {}
body_params = None
if 'upsert_vertex_request' in local_var_params:
body_params = local_var_params['upsert_vertex_request']
if "upsert_vertex_request" in local_var_params:
body_params = local_var_params["upsert_vertex_request"]
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
header_params["Accept"] = self.api_client.select_header_accept(
["application/json"]
) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
header_params[
"Content-Type"
] = self.api_client.select_header_content_type( # noqa: E501
["application/json"]
) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/upsertVertex', 'POST',
"/graph/upsertVertex",
"POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
response_type="object", # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
async_req=local_var_params.get("async_req"),
_return_http_data_only=local_var_params.get(
"_return_http_data_only"
), # noqa: E501
_preload_content=local_var_params.get("_preload_content", True),
_request_timeout=local_var_params.get("_request_timeout"),
collection_formats=collection_formats,
)
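The GraphApi changes are likewise mechanical (quote style and call formatting in the OpenAPI-generated client). A hedged sketch of invoking one of these endpoints follows; the ApiClient configuration, the GraphApi constructor, and the module path are assumptions based on the usual openapi-generator layout, not confirmed by this diff.

# Hedged sketch of calling the generated REST client; names below are assumptions.
from kag.common.rest.api_client import ApiClient
from kag.common.graphstore.rest.graph_api import GraphApi  # assumed module path
from kag.common.graphstore.rest.models.delete_edge_request import DeleteEdgeRequest

api_client = ApiClient()  # assumed to default to the configured OpenSPG server address
graph_api = GraphApi(api_client)  # generated clients conventionally accept an ApiClient

request = DeleteEdgeRequest(project_id=1, edges=[])  # edges: list[EdgeRecordInstance]
response = graph_api.graph_delete_edge_post(delete_edge_request=request)
print(response)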

View File

@ -16,4 +16,6 @@ from kag.common.graphstore.rest.models.delete_vertex_request import DeleteVertex
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance
from kag.common.graphstore.rest.models.vertex_record_instance import (
VertexRecordInstance,
)

View File

@ -32,17 +32,13 @@ class DeleteEdgeRequest(object):
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'project_id': 'int',
'edges': 'list[EdgeRecordInstance]'
}
openapi_types = {"project_id": "int", "edges": "list[EdgeRecordInstance]"}
attribute_map = {
'project_id': 'projectId',
'edges': 'edges'
}
attribute_map = {"project_id": "projectId", "edges": "edges"}
def __init__(self, project_id=None, edges=None, local_vars_configuration=None): # noqa: E501
def __init__(
self, project_id=None, edges=None, local_vars_configuration=None
): # noqa: E501
"""DeleteEdgeRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@ -73,8 +69,12 @@ class DeleteEdgeRequest(object):
:param project_id: The project_id of this DeleteEdgeRequest. # noqa: E501
:type: int
"""
if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
if (
self.local_vars_configuration.client_side_validation and project_id is None
): # noqa: E501
raise ValueError(
"Invalid value for `project_id`, must not be `None`"
) # noqa: E501
self._project_id = project_id
@ -96,8 +96,12 @@ class DeleteEdgeRequest(object):
:param edges: The edges of this DeleteEdgeRequest. # noqa: E501
:type: list[EdgeRecordInstance]
"""
if self.local_vars_configuration.client_side_validation and edges is None: # noqa: E501
raise ValueError("Invalid value for `edges`, must not be `None`") # noqa: E501
if (
self.local_vars_configuration.client_side_validation and edges is None
): # noqa: E501
raise ValueError(
"Invalid value for `edges`, must not be `None`"
) # noqa: E501
self._edges = edges
@ -108,18 +112,20 @@ class DeleteEdgeRequest(object):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
result[attr] = list(
map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
)
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
result[attr] = dict(
map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict")
else item,
value.items(),
)
)
else:
result[attr] = value
