commit 179c52ec18

@@ -48,11 +48,10 @@ RUN if [ "$SKIP_LLAMA_BUILD" = "false" ]; then \
    mkdir -p build && \
    cd build && \
    echo "Starting CMake configuration with CUDA support..." && \
-   cmake -DGGML_CUDA=ON \
+   cmake -DGGML_CUDA=OFF -DLLAMA_CUBLAS=OFF \
        -DCMAKE_BUILD_TYPE=Release \
        -DBUILD_SHARED_LIBS=OFF \
-       -DLLAMA_NATIVE=OFF \
-       -DCMAKE_CUDA_FLAGS="-Wno-deprecated-gpu-targets" \
+       -DLLAMA_NATIVE=ON \
        .. && \
    echo "Starting build process (this will take several minutes)..." && \
    cmake --build . --config Release -j --verbose && \
@@ -216,9 +216,6 @@ export default function TrainingPage() {
    if (trainingProgress.status === 'in_progress') {
      setIsTraining(true);

-     // Create EventSource connection to get logs
-     updateTrainLog();
-
      if (firstLoadRef.current) {
        scrollPageToBottom();

@@ -234,19 +231,15 @@ export default function TrainingPage() {
    ) {
      stopPolling();
      setIsTraining(false);

      // Keep EventSource open to preserve received logs
      // If resource cleanup is needed, EventSource could be closed here
    }

    // Return cleanup function to ensure EventSource is closed when component unmounts or dependencies change
    return () => {
      if (cleanupEventSourceRef.current) {
        cleanupEventSourceRef.current();
      }
    };
  }, [trainingProgress]);

+ useEffect(() => {
+   if (isTraining) {
+     updateTrainLog();
+   }
+ }, [isTraining]);

  // Cleanup when component unmounts
  useEffect(() => {
    return () => {
@@ -297,7 +290,7 @@ export default function TrainingPage() {

  const getDetails = () => {
    // Use EventSource to get logs
-   const eventSource = new EventSource('/api/trainprocess/logs');
+   const eventSource = new EventSource(`/api/trainprocess/logs`);

    eventSource.onmessage = (event) => {
      // Don't try to parse as JSON, just use the raw text data directly
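The log viewer above consumes `/api/trainprocess/logs` as a Server-Sent Events stream of raw text lines. A minimal Python client is handy for watching the same stream without the UI; this is a sketch, assuming the backend is reachable at a hypothetical BASE_URL and that `requests` is installed:

```python
# Sketch of an SSE client for the training-log stream, for debugging
# outside the browser. BASE_URL is an assumption, not from this commit.
# Mirrors the frontend: payloads are raw log text, not JSON.
import requests

BASE_URL = "http://localhost:8002"  # hypothetical; adjust to your deployment

def follow_train_logs() -> None:
    with requests.get(f"{BASE_URL}/api/trainprocess/logs", stream=True) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw:
                continue  # blank separators between SSE events
            if raw.startswith(":"):
                continue  # SSE comments such as ":heartbeat" carry no data
            if raw.startswith("data: "):
                print(raw[len("data: "):])  # raw log line, no JSON parsing

if __name__ == "__main__":
    follow_train_logs()
```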
@@ -366,6 +359,7 @@ export default function TrainingPage() {
      if (res.data.code === 0) {
        setTrainSuspended(false);
        resetTrainingState();
        localStorage.removeItem('trainingLogs');
      } else {
        throw new Error(res.data.message || 'Failed to reset progress');
      }
@@ -1,4 +1,8 @@
import type { TrainProgress } from '@/service/train';
+import type { IStepOutputInfo } from '../trainExposureModel';
+import TrainExposureModel from '../trainExposureModel';
+import { useState } from 'react';
+import classNames from 'classnames';

interface TrainingProgressProps {
  trainingProgress: TrainProgress;
@@ -16,6 +20,8 @@ const descriptionMap = [
const TrainingProgress = (props: TrainingProgressProps) => {
  const { trainingProgress, status } = props;

+ const [stepOutputInfo, setStepOutputInfo] = useState<IStepOutputInfo>({} as IStepOutputInfo);
+
  const formatUnderscoreToName = (_str: string) => {
    const str = _str || '';

@@ -24,6 +30,13 @@ const TrainingProgress = (props: TrainingProgressProps) => {
      .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
      .join(' ');
  };

+ const formatToUnderscore = (str: string): string => {
+   if (!str) return '';
+
+   return str.toLowerCase().replace(/\s+/g, '_');
+ };
+
  const trainingStages = trainingProgress.stages.map((stage, index) => {
    return { ...stage, description: descriptionMap[index] };
  });
@@ -220,15 +233,30 @@ const TrainingProgress = (props: TrainingProgressProps) => {
                )}
              </div>
              <span
-               className={`text-xs ${
+               className={classNames(
+                 'text-xs',
                  stage.current_step &&
-                 formatUnderscoreToName(stage.current_step) == step.name
+                   formatUnderscoreToName(stage.current_step) == step.name
                    ? 'text-blue-600 font-medium'
                    : 'text-gray-600'
-               }`}
+                 // step.completed ? 'hover:text-green-600 cursor-pointer' : ''
+               )}
              >
                {step.name}
              </span>
+             {step.completed && step.have_output && (
+               <span
+                 className="text-xs text-blue-500 underline cursor-pointer hover:text-blue-600"
+                 onClick={() => {
+                   setStepOutputInfo({
+                     stepName: formatToUnderscore(step.name),
+                     path: step.path
+                   });
+                 }}
+               >
+                 View Resources
+               </span>
+             )}
            </div>
          ))}
        </div>
@@ -242,6 +270,11 @@ const TrainingProgress = (props: TrainingProgressProps) => {
        </div>
      </div>
    </div>

+   <TrainExposureModel
+     handleClose={() => setStepOutputInfo({} as IStepOutputInfo)}
+     stepOutputInfo={stepOutputInfo}
+   />
  </div>
  );
};
@@ -0,0 +1,120 @@
import type { TrainStepOutput } from '@/service/train';
import { getStepOutputContent } from '@/service/train';
import { Modal, Table } from 'antd';
import { useEffect, useState } from 'react';
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
import { tomorrow } from 'react-syntax-highlighter/dist/cjs/styles/prism';

export interface IStepOutputInfo {
  path?: string;
  stepName: string;
}
interface IProps {
  handleClose: () => void;
  stepOutputInfo?: IStepOutputInfo;
}

const TrainExposureModel = (props: IProps) => {
  const { handleClose, stepOutputInfo } = props;
  const [outputContent, setOutputContent] = useState<TrainStepOutput | null>(null);
  const [loading, setLoading] = useState<boolean>(true);

  useEffect(() => {
    if (!stepOutputInfo?.stepName) return;

    setOutputContent(null);
    setLoading(true);

    getStepOutputContent(stepOutputInfo.stepName)
      .then((res) => {
        if (res.data.code == 0) {
          const data = res.data.data;

          setOutputContent(data);
        } else {
          console.error(res.data.message);
        }
      })
      .finally(() => {
        setLoading(false);
      });
  }, [stepOutputInfo?.stepName]);

  const renderOutputContent = () => {
    if (loading) {
      return (
        <div className="flex items-center justify-center w-full py-12">
          <div className="flex flex-col items-center space-y-4">
            <div className="relative w-12 h-12">
              <div className="absolute w-12 h-12 rounded-full border-2 border-gray-200" />
              <div
                className="absolute w-12 h-12 rounded-full border-2 border-t-blue-500 animate-spin"
                style={{ animationDuration: '1.2s' }}
              />
            </div>
            <p className="text-gray-500 text-sm">loading...</p>
          </div>
        </div>
      );
    }

    if (!outputContent) return 'There are no resources for this step at this time';

    if (outputContent.file_type == 'json') {
      const showContent = JSON.stringify(outputContent.content, null, 2);

      return (
        <SyntaxHighlighter
          customStyle={{
            backgroundColor: 'transparent',
            margin: 0,
            padding: 0
          }}
          language="json"
          style={tomorrow}
        >
          {showContent}
        </SyntaxHighlighter>
      );
    }

    if (outputContent.file_type == 'parquet') {
      const columns = outputContent.columns.map((item, index) => ({
        title: item,
        dataIndex: item,
        key: index
      }));
      const data = outputContent.content;

      return (
        <Table className="w-fit max-w-fit" columns={columns} dataSource={data} pagination={false} />
      );
    }

    return 'There are no resources for this step at this time';
  };

  return (
    <Modal
      centered
      closable={false}
      footer={null}
      onCancel={handleClose}
      open={!!stepOutputInfo?.stepName}
      width={800}
    >
      <div className="flex flex-col">
        {stepOutputInfo?.path && (
          <div className="flex items-center justify-between mb-4">
            <span className="text-lg font-medium text-gray-900">{`path: ${stepOutputInfo.path}`}</span>
          </div>
        )}
        <div className="bg-[#f5f5f5] flex flex-col max-h-[600px] w-full overflow-scroll border border-[#e0e0e0] rounded p-4 font-mono text-sm leading-6 text-[#333] shadow-sm transition-all duration-300 ease-in-out">
          {renderOutputContent()}
        </div>
      </div>
    </Modal>
  );
};

export default TrainExposureModel;
@@ -25,6 +25,8 @@ interface TrainStep {
  completed: boolean;
  name: string;
  status: StepStatus;
+ path?: string;
+ have_output?: boolean;
}

interface TrainStage {
@@ -34,6 +36,17 @@ interface TrainStage {
  steps: TrainStep[];
  current_step: string | null;
}
+export interface TrainStepJson {
+  content: object[];
+  file_type: 'json';
+}
+
+export interface TrainStepParquet {
+  columns: string[];
+  content: object[];
+  file_type: 'parquet';
+}
+export type TrainStepOutput = TrainStepJson | TrainStepParquet;

export type StageName =
  | 'downloading_the_base_model'
@@ -155,3 +168,10 @@ export const checkCudaAvailability = () => {
    url: '/api/kernel2/cuda/available'
  });
};
+
+export const getStepOutputContent = (stepName: string) => {
+  return Request<CommonResponse<TrainStepOutput>>({
+    method: 'get',
+    url: `/api/trainprocess/step_output_content?step_name=${stepName}`
+  });
+};
@@ -584,8 +584,8 @@ class L2DataProcessor:
                text=True,
                capture_output=True,
            )
-           logger.error(f"subprocess.run graphrag index error: {result.stderr}")
            if result.stderr:
                logger.error(f"subprocess.run graphrag index error: {result.stderr}")
+               raise RuntimeError("subprocess.run graphrag index error")
        except Exception as e:
            raise
@@ -1 +1 @@
-graphrag index --config lpm_kernel/L2/data_pipeline/graphrag_indexing/settings.yaml --root lpm_kernel/L2/data_pipeline/graphrag_indexing
+graphrag index --config lpm_kernel/L2/data_pipeline/graphrag_indexing/settings.yaml --root lpm_kernel/L2/data_pipeline/graphrag_indexing --method standard --logger none
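The error check two hunks above treats any stderr output from `graphrag index` as fatal, but many CLIs emit warnings on stderr even when they succeed. A variant keyed on the process exit code can be more robust; this is an illustrative sketch using the command from this commit, not the committed code:

```python
# Sketch: fail on the exit code rather than on any stderr output, so
# warning text printed to stderr by graphrag does not abort the pipeline.
import subprocess

CMD = [
    "graphrag", "index",
    "--config", "lpm_kernel/L2/data_pipeline/graphrag_indexing/settings.yaml",
    "--root", "lpm_kernel/L2/data_pipeline/graphrag_indexing",
    "--method", "standard",
    "--logger", "none",
]

result = subprocess.run(CMD, text=True, capture_output=True)
if result.stderr:
    # Surface stderr either way; it may hold warnings even on success.
    print(f"graphrag index stderr: {result.stderr}")
if result.returncode != 0:
    raise RuntimeError(f"graphrag index failed with exit code {result.returncode}")
```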
@@ -445,35 +445,10 @@ def create_and_prepare_model(args, data_args, training_args, model_kwargs=None):
        if args.lora_target_modules != "all-linear"
        else args.lora_target_modules,
    )

-   # Load tokenizer - tokenizers are usually small and don't need memory management
-   special_tokens = None
-   chat_template = None
-   if args.chat_template_format == "chatml":
-       special_tokens = ChatmlSpecialTokens
-       chat_template = DEFAULT_CHATML_CHAT_TEMPLATE
-   elif args.chat_template_format == "zephyr":
-       special_tokens = ZephyrSpecialTokens
-       chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE
-
-   if special_tokens is not None:
-       tokenizer = AutoTokenizer.from_pretrained(
-           args.model_name_or_path,
-           pad_token=special_tokens.pad_token.value,
-           bos_token=special_tokens.bos_token.value,
-           eos_token=special_tokens.eos_token.value,
-           additional_special_tokens=special_tokens.list(),
-           trust_remote_code=True,
-           padding_side="right",
-       )
-       tokenizer.chat_template = chat_template
-   else:
-       tokenizer = AutoTokenizer.from_pretrained(
-           args.model_name_or_path, trust_remote_code=True
-       )
-       # Make sure pad_token is set
-       if tokenizer.pad_token is None:
-           tokenizer.pad_token = tokenizer.eos_token
-
+   tokenizer = AutoTokenizer.from_pretrained(
+       args.model_name_or_path, trust_remote_code=True, padding_side="right"
+   )

    # Apply Unsloth LoRA if requested and check memory status
    if args.use_unsloth:
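The simplified loader above drops the explicit pad-token handling that the removed branch carried. Tokenizers that ship without a pad token (common for base LLMs) still need one for right-padded batch training; a minimal sketch of the usual EOS fallback, assuming a Hugging Face model id such as the Qwen one used elsewhere in this repo:

```python
# Sketch of the pad-token fallback the removed branch used to apply.
# The model id is an example; some base models define no pad token, and
# reusing the EOS token is the common fallback for padded batches.
from transformers import AutoTokenizer

model_name_or_path = "Qwen/Qwen2.5-0.5B-Instruct"  # example from this codebase

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path, trust_remote_code=True, padding_side="right"
)
if tokenizer.pad_token is None:
    # Fall back to EOS so collator-style padding has a valid token id.
    tokenizer.pad_token = tokenizer.eos_token
```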
@@ -1,4 +1,3 @@
import json
import logging
import os
import time
@@ -7,13 +6,9 @@ import torch # Add torch import for CUDA detection
import traceback
from dataclasses import asdict

-from flask import Blueprint, jsonify, Response, request
+from flask import Blueprint, jsonify, request
from flask_pydantic import validate

from lpm_kernel.L1.serializers import NotesStorage
from lpm_kernel.L1.utils import save_true_topics
from lpm_kernel.L2.l2_generator import L2Generator
from lpm_kernel.L2.utils import save_hf_model
from lpm_kernel.api.common.responses import APIResponse
from lpm_kernel.api.domains.kernel2.dto.chat_dto import (
    ChatRequest,
@@ -24,20 +19,11 @@ from lpm_kernel.api.domains.kernel2.services.prompt_builder import (
    RoleBasedStrategy,
    KnowledgeEnhancedStrategy,
)
from lpm_kernel.api.domains.kernel2.services.role_service import role_service
from lpm_kernel.api.domains.loads.services import LoadService
from lpm_kernel.api.services.local_llm_service import local_llm_service
from lpm_kernel.kernel.chunk_service import ChunkService
from lpm_kernel.kernel.l1.l1_manager import (
    extract_notes_from_documents,
    document_service,
    get_latest_status_bio,
    get_latest_global_bio,
)

from ...common.script_executor import ScriptExecutor
from ...common.script_runner import ScriptRunner
from ....configs.config import Config
from ....kernel.note_service import NoteService

logger = logging.getLogger(__name__)

@@ -70,45 +56,6 @@ def health_check():
    return jsonify(APIResponse.success(data={"status": "stopped"}))


-@kernel2_bp.route("/model/download", methods=["POST"])
-def downloadModel():
-    """Download base model
-
-    Request body:
-    {
-        "model_name": str  # Model name, e.g. "Qwen/Qwen2.5-0.5B-Instruct"
-    }
-
-    Returns:
-    {
-        "code": int,
-        "message": str,
-        "data": {
-            "model_path": str  # Model save path
-        }
-    }
-    """
-    try:
-        data = request.get_json()
-        if not data or "model_name" not in data:
-            return jsonify(APIResponse.error(message="Missing required parameter: model_name", code=400))
-
-        model_name = data["model_name"]
-
-        # Download and save model
-        model_path = save_hf_model(model_name)
-
-        return jsonify(APIResponse.success(
-            data={"model_path": model_path},
-            message="Model download completed"
-        ))
-
-    except Exception as e:
-        error_msg = f"Failed to download model: {str(e)}"
-        logger.error(error_msg)
-        return jsonify(APIResponse.error(message=error_msg, code=500))
-
-
@kernel2_bp.route("/username", methods=["GET"])
def username():
    return jsonify(APIResponse.success(data={"username": LoadService.get_current_upload_name()}))
@@ -117,475 +64,6 @@ def username():
@kernel2_bp.route("/docker/env", methods=["GET"])
def docker_env():
    return jsonify(APIResponse.success(data={"in_docker_env": os.getenv("IN_DOCKER_ENV")}))


-@kernel2_bp.route("/data/prepare", methods=["POST"])
-def all():
-    def generate():
-        try:
-            # 1. Initialize configuration and directories (5%)
-            progress_data = {
-                "stage": "Initializing",
-                "progress": 5,
-                "message": "Initializing configuration and directories"
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            config = Config.from_env()
-            base_dir = os.path.join(
-                os.getcwd(), config.get("USER_DATA_PIPELINE_DIR") + "/raw_data"
-            )
-            os.makedirs(base_dir, exist_ok=True)
-
-            # 2. Process topics data (15%)
-            progress_data = {
-                "stage": "Processing Topics",
-                "progress": 15,
-                "message": "Saving topics data"
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            chunk_service = ChunkService()
-            topics_data = chunk_service.query_topics_data()
-            save_true_topics(topics_data, os.path.join(base_dir, "topics.json"))
-
-            # 3. Process documents and notes (35%)
-            progress_data = {
-                "stage": "Processing Documents",
-                "progress": 35,
-                "message": "Extracting and preparing document notes"
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            documents = document_service.list_documents_with_l0()
-            notes_list, _ = extract_notes_from_documents(documents)
-            if not notes_list:
-                error_data = {
-                    "stage": "Error",
-                    "progress": -1,
-                    "message": "No notes found"
-                }
-                yield f"data: {json.dumps(error_data)}\n\n"
-                return
-
-            note_service = NoteService()
-            note_service.prepareNotes(notes_list)
-
-            storage = NotesStorage()
-            result = storage.save_notes(notes_list)
-
-            # 4. Prepare configuration files and paths (50%)
-            progress_data = {
-                "stage": "Preparing Configuration",
-                "progress": 50,
-                "message": "Preparing L2 generator configuration "
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            config_path = os.path.join(
-                os.getcwd(),
-                "resources/L2/data_pipeline/data_prep/subjective/config/config.json",
-            )
-            entitys_path = os.path.join(
-                os.getcwd(),
-                "resources/L2/data_pipeline/raw_data/id_entity_mapping_subjective_v2.json",
-            )
-            graph_path = os.path.join(
-                os.getcwd(),
-                "resources/L1/graphrag_indexing_output/subjective/entities.parquet",
-            )
-
-            data_output_base_dir = os.path.join(os.getcwd(), "resources/L2/data")
-            notes = storage.load_notes()
-
-            # 5. Prepare basic information (65%)
-            progress_data = {
-                "stage": "Preparing Basic Info",
-                "progress": 65,
-                "message": "Getting user information and bio"
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            status_bio = get_latest_status_bio()
-            global_bio = get_latest_global_bio()
-
-            basic_info = {
-                "username": LoadService.get_current_upload_name(),
-                "aboutMe": LoadService.get_current_upload_description(),
-                "statusBio": status_bio.content
-                if status_bio
-                else "Currently working on an AI project.",
-                "globalBio": global_bio.content_third_view
-                if global_bio
-                else "The User is a software engineer who loves programming and learning new technologies.",
-                "lang": "English",
-            }
-
-            # 6. Data preprocessing (80%)
-            progress_data = {
-                "stage": "Data Preprocessing",
-                "progress": 80,
-                "message": "Executing data preprocessing"
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            l2_generator = L2Generator(
-                data_path=os.path.join(os.getcwd(), "resources")
-            )
-            l2_generator.data_preprocess(notes, basic_info)
-
-            # 7. Generate subjective data (95%)
-            progress_data = {
-                "stage": "Generating Data",
-                "progress": 95,
-                "message": "Generating subjective data: Preference QA Self QA Diversity QA graphrag_indexing"
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-            l2_generator.gen_subjective_data(
-                notes,
-                basic_info,
-                data_output_base_dir,
-                storage.topics_path,
-                entitys_path,
-                graph_path,
-                config_path,
-            )
-
-            # 8. Complete (100%)
-            progress_data = {
-                "stage": "Complete",
-                "progress": 100,
-                "message": "Data preparation completed",
-                "result": {
-                    "bio": basic_info["globalBio"],
-                    "document_clusters": "Generated document clusters",
-                    "chunk_topics": "Generated chunk topics"
-                }
-            }
-            yield f"data: {json.dumps(progress_data)}\n\n"
-
-        except Exception as e:
-            logger.error(f"Data preparation failed: {str(e)}", exc_info=True)
-            error_data = {
-                "stage": "Error",
-                "progress": -1,
-                "message": f"Data preparation failed: {str(e)}"
-            }
-            yield f"data: {json.dumps(error_data)}\n\n"
-
-    return Response(
-        generate(),
-        mimetype='text/event-stream',
-        headers={
-            'Cache-Control': 'no-cache',
-            'Connection': 'keep-alive',
-            'X-Accel-Buffering': 'no'  # Disable Nginx buffering
-        }
-    )
-
-
-# Global variables for tracking training process
-_training_process = None
-_training_thread = None
-_stopping_training = False
-
-
-def get_model_paths(model_name: str) -> dict:
-    """
-    Get all paths related to the model
-
-    Args:
-        model_name: Model name
-
-    Returns:
-        Dictionary containing all related paths:
-        - base_path: Base model path
-        - personal_dir: Personal trained model output directory
-        - merged_dir: Merged model output directory
-    """
-    base_dir = os.getcwd()
-    paths = {
-        "base_path": os.path.join(base_dir, "resources/L2/base_models", model_name),
-        "personal_dir": os.path.join(base_dir, "resources/model/output/personal_model", model_name),
-        "merged_dir": os.path.join(base_dir, "resources/model/output/merged_model", model_name),
-        "gguf_dir": os.path.join(base_dir, "resources/model/output/gguf", model_name)
-    }
-
-    # Ensure all directories exist
-    for path in paths.values():
-        os.makedirs(path, exist_ok=True)
-
-    return paths
-
-
-def start_training(script_path: str, log_path: str) -> None:
-    """Start training in a new thread"""
-    global _training_process
-    try:
-        # Use ScriptRunner to execute the script
-        runner = ScriptRunner(log_path=log_path)
-        _training_process = runner.execute_script(
-            script_path=script_path,
-            script_type="training",
-            is_python=False,  # This is a bash script
-        )
-
-        logger.info(f"Training process started: {_training_process}")
-
-    except Exception as e:
-        logger.error(f"Failed to start training process: {str(e)}")
-        _training_process = None
-        raise
-
-
-@kernel2_bp.route("/train2", methods=["POST"])
-def train2():
-    """Start model training"""
-    global _training_thread, _training_process, _stopping_training
-
-    try:
-        # Get request parameters
-        data = request.get_json()
-        if not data or "model_name" not in data:
-            return jsonify(APIResponse.error(message="Missing required parameter: model_name", code=400))
-
-        model_name = data["model_name"]
-        paths = get_model_paths(model_name)
-
-        # Get optional parameters with defaults
-        learning_rate = data.get("learning_rate", 2e-4)
-        num_train_epochs = data.get("number_of_epochs", 3)
-        concurrency_threads = data.get("concurrency_threads", 2)
-        data_synthesis_mode = data.get("data_synthesis_mode", "low")
-        use_cuda = data.get("use_cuda", False)
-
-        # Convert use_cuda to string "True" or "False" for the shell script
-        use_cuda_str = "True" if use_cuda else "False"
-
-        logger.info(f"Training configuration: learning_rate={learning_rate}, epochs={num_train_epochs}, "
-                    f"threads={concurrency_threads}, mode={data_synthesis_mode}, use_cuda={use_cuda} ({use_cuda_str})")
-
-        # Check if model exists
-        if not os.path.exists(paths["base_path"]):
-            return jsonify(APIResponse.error(
-                message=f"Model '{model_name}' does not exist, please download first",
-                code=400
-            ))
-
-        # Check if training is already running
-        if _training_thread and _training_thread.is_alive():
-            return jsonify(APIResponse.error("Training task is already running"))
-
-        # Reset stopping flag
-        _stopping_training = False
-
-        # Prepare log directory and file
-        log_dir = os.path.join(os.getcwd(), "logs")
-        os.makedirs(log_dir, exist_ok=True)
-        log_path = os.path.join(log_dir, "train.log")
-        logger.info(f"Log file path: {log_path}")
-
-        # Ensure output directory exists
-        os.makedirs(paths["personal_dir"], exist_ok=True)
-
-        # Set environment variables
-        os.environ["MODEL_BASE_PATH"] = paths["base_path"]
-        os.environ["MODEL_PERSONAL_DIR"] = paths["personal_dir"]
-        # Assign
-        os.environ["USER_NAME"] = LoadService.get_current_upload_name()
-
-        logger.info(f"Environment variables set: {os.environ}")
-
-        script_path = os.path.join(os.getcwd(), "lpm_kernel/L2/train_for_user.sh")
-
-        # Build command arguments
-        cmd_args = [
-            "--lr", str(learning_rate),
-            "--epochs", str(num_train_epochs),
-            "--threads", str(concurrency_threads),
-            "--mode", str(data_synthesis_mode),
-            "--cuda", use_cuda_str  # Use the properly formatted string
-        ]
-
-        # Start training
-        import threading
-        _training_thread = threading.Thread(
-            target=start_training_with_args,
-            args=(script_path, log_path, cmd_args),
-            daemon=True
-        )
-        _training_thread.start()
-
-        return jsonify(APIResponse.success(
-            data={
-                "status": "training_started",
-                "model_name": model_name,
-                "log_path": log_path,
-            },
-            message="Training task started successfully"
-        ))
-
-    except Exception as e:
-        logger.error(f"Error starting training task: {str(e)}")
-        traceback.print_exc()
-        return jsonify(APIResponse.error(message=f"Failed to start training: {str(e)}"))
-
-
-def start_training_with_args(script_path: str, log_path: str, args: list) -> None:
-    """Start training with additional arguments"""
-    global _training_process
-    try:
-        # Convert script path and args to a command
-        cmd = [script_path] + args
-
-        # Use ScriptRunner to execute the script
-        runner = ScriptRunner(log_path=log_path)
-        _training_process = runner.execute_script(
-            script_path=script_path,
-            script_type="training",
-            is_python=False,  # This is a bash script
-            args=args
-        )
-
-        logger.info(f"Training process started with args: {args}, process: {_training_process}")
-
-    except Exception as e:
-        logger.error(f"Failed to start training process: {str(e)}")
-        _training_process = None
-        raise
-
-
-@kernel2_bp.route("/merge_weights", methods=["POST"])
-def merge_weights():
-    """Merge model weights"""
-    try:
-        # Get request parameters
-        data = request.get_json()
-        if not data or "model_name" not in data:
-            return jsonify(APIResponse.error(message="Missing required parameter: model_name", code=400))
-
-        model_name = data["model_name"]
-        paths = get_model_paths(model_name)
-
-        # Check if model exists
-        if not os.path.exists(paths["base_path"]):
-            return jsonify(APIResponse.error(
-                message=f"Model '{model_name}' does not exist, please download first",
-                code=400
-            ))
-
-        # Check if training output exists
-        if not os.path.exists(paths["personal_dir"]):
-            return jsonify(APIResponse.error(
-                message=f"Model '{model_name}' training output does not exist, please train model first",
-                code=400
-            ))
-
-        # Ensure merged output directory exists
-        os.makedirs(paths["merged_dir"], exist_ok=True)
-
-        # Set environment variables
-        os.environ["MODEL_BASE_PATH"] = paths["base_path"]
-        os.environ["MODEL_PERSONAL_DIR"] = paths["personal_dir"]
-        os.environ["MODEL_MERGED_DIR"] = paths["merged_dir"]
-
-        logger.info(f"Environment variables set: MODEL_BASE_PATH : {os.environ}")
-
-        script_path = os.path.join(
-            os.getcwd(), "lpm_kernel/L2/merge_weights_for_user.sh"
-        )
-        log_path = os.path.join(os.getcwd(), "logs", f"merge_weights_{model_name}.log")
-
-        # Ensure log directory exists
-        os.makedirs(os.path.dirname(log_path), exist_ok=True)
-
-        # Use script executor to execute merge script
-        result = script_executor.execute(
-            script_path=script_path, script_type="merge_weights", log_file=log_path
-        )
-
-        return jsonify(
-            APIResponse.success(
-                data={
-                    **result,
-                    "model_name": model_name,
-                    "log_path": log_path,
-                    "personal_dir": paths["personal_dir"],
-                    "merged_dir": paths["merged_dir"]
-                },
-                message="Weight merge task started"
-            )
-        )
-
-    except Exception as e:
-        error_msg = f"Failed to start weight merge: {str(e)}"
-        logger.error(error_msg)
-        return jsonify(APIResponse.error(message=error_msg, code=500))
-
-
-@kernel2_bp.route("/convert_model", methods=["POST"])
-def convert_model():
-    """Convert model to GGUF format"""
-    try:
-        # Get request parameters
-        data = request.get_json()
-        logger.info(f"Request parameters: {data}")
-        if not data or "model_name" not in data:
-            return jsonify(APIResponse.error(message="Missing required parameter: model_name", code=400))
-
-        model_name = data["model_name"]
-        logger.info(f"Converting model: {model_name}")
-        paths = get_model_paths(model_name)
-
-        # Check if merged model exists
-        merged_model_dir = paths["merged_dir"]
-        logger.info(f"Merged model path: {merged_model_dir}")
-        if not os.path.exists(merged_model_dir):
-            return jsonify(APIResponse.error(
-                message=f"Model '{model_name}' merged output does not exist, please merge model first",
-                code=400
-            ))
-
-        # Get GGUF output directory
-        gguf_dir = paths["gguf_dir"]
-        logger.info(f"GGUF output directory: {gguf_dir}")
-
-        script_path = os.path.join(os.getcwd(), "lpm_kernel/L2/convert_hf_to_gguf.py")
-        gguf_path = os.path.join(gguf_dir, "model.gguf")
-        logger.info(f"GGUF output path: {gguf_path}")
-
-        # Build parameters
-        args = [
-            merged_model_dir,
-            "--outfile",
-            gguf_path,
-            "--outtype",
-            "f16",
-        ]
-        logger.info(f"Parameters: {args}")
-        # Use script executor to execute conversion script
-        result = script_executor.execute(
-            script_path=script_path, script_type="convert_model", args=args
-        )
-
-        logger.info(f"Model conversion successful: {result}")
-        return jsonify(APIResponse.success(
-            data={
-                **result,
-                "model_name": model_name,
-                "merged_dir": merged_model_dir,
-                "gguf_path": gguf_path
-            },
-            message="Model conversion task started"
-        ))
-
-    except Exception as e:
-        error_msg = f"Failed to start model conversion: {str(e)}"
-        logger.error(error_msg)
-        return jsonify(APIResponse.error(message=error_msg, code=500))
-
-
@kernel2_bp.route("/llama/start", methods=["POST"])
def start_llama_server():
@@ -599,9 +77,9 @@ def start_llama_server():
    model_name = data["model_name"]
    # Get optional use_gpu parameter with default value of True
    use_gpu = data.get("use_gpu", True)

-   paths = get_model_paths(model_name)
-   gguf_path = os.path.join(paths["gguf_dir"], "model.gguf")
+   base_dir = os.getcwd()
+   model_dir = os.path.join(base_dir, "resources/model/output/gguf", model_name)
+   gguf_path = os.path.join(model_dir, "model.gguf")

    server_path = os.path.join(os.getcwd(), "llama.cpp/build/bin")
    if os.path.exists(os.path.join(os.getcwd(), "llama.cpp/build/bin/Release")):
@@ -654,7 +132,6 @@ def start_llama_server():
# Flag to track if service is stopping
_stopping_server = False


@kernel2_bp.route("/llama/stop", methods=["POST"])
def stop_llama_server():
    """Stop llama-server service - Force immediate termination of the process"""
@@ -706,28 +183,6 @@ def get_llama_server_status():
        logger.error(f"Error getting llama-server status: {str(e)}", exc_info=True)
        return APIResponse.error(f"Error getting llama-server status: {str(e)}")


-@kernel2_bp.route("/test/version", methods=["GET"])
-def test_version():
-    """Test environment version"""
-    try:
-        # Execute python command directly to get version
-        result = script_executor.execute(
-            script_path="python", script_type="version_check", args=["--version"]
-        )
-
-        return jsonify(
-            APIResponse.success(
-                data={"python_version": result}, message="Version information obtained successfully"
-            )
-        )
-
-    except Exception as e:
-        error_msg = f"Failed to get version information: {str(e)}"
-        logger.error(error_msg)
-        return jsonify(APIResponse.error(error_msg))
-
-
@kernel2_bp.route("/chat", methods=["POST"])
@validate()
def chat(body: ChatRequest):
@@ -790,7 +245,7 @@ def chat(body: ChatRequest):
            }
            # Return as regular JSON response for non-stream or stream-compatible error
            if not body.stream:
-               return APIResponse.error(message="服务暂时不可用", code=503), 503
+               return APIResponse.error(message="Service temporarily unavailable", code=503), 503
            return local_llm_service.handle_stream_response(iter([error_response]))

    try:
@@ -445,14 +445,6 @@ class LoadService:

        logger.info("Reset default training progress")

-       # Reset global training process variables
-       from lpm_kernel.api.domains.kernel2.routes_l2 import _training_process, _training_thread, _stopping_training
-       if _training_process is not None:
-           logger.info("Resetting global training process variables")
-           _training_process = None
-           _training_thread = None
-           _stopping_training = False
-
    except Exception as e:
        logger.error(f"Failed to reset training progress objects: {str(e)}")
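Dropping this block also matches Python semantics: `from module import name` copies a binding, so assigning to the imported name never touched the globals in routes_l2. If such a reset were still wanted (and the globals still lived in routes_l2, which this commit also removes), rebinding the module attribute is the pattern that works; a hedged sketch:

```python
# Sketch only: module attributes must be reset through the module object,
# not through names copied by `from ... import ...`. Assumes the globals
# still exist in routes_l2, which is no longer true after this commit.
import lpm_kernel.api.domains.kernel2.routes_l2 as routes_l2

def reset_training_globals() -> None:
    # Rebinding attributes on the module is visible to all other importers.
    routes_l2._training_process = None
    routes_l2._training_thread = None
    routes_l2._stopping_training = False
```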
@@ -0,0 +1,133 @@
import os
import json
import pandas as pd
import numpy as np
from typing import Optional, Dict
from lpm_kernel.models.l1 import L1Bio, L1Shade, L1Cluster, L1ChunkTopic
from lpm_kernel.common.repository.database_session import DatabaseSession

# Output file mapping for each process step
output_files = {
    "extract_dimensional_topics": os.path.join(os.getcwd(), "resources/L2/data_pipeline/raw_data/topics.json"),
    "map_your_entity_network": os.path.join(os.getcwd(), "resources/L1/graphrag_indexing_output/subjective/entities.parquet"),
    "decode_preference_patterns": os.path.join(os.getcwd(), "resources/L2/data/preference.json"),
    "reinforce_identity": os.path.join(os.getcwd(), "resources/L2/data/selfqa.json"),
    "augment_content_retention": os.path.join(os.getcwd(), "resources/L2/data/diversity.json"),
}

def query_l1_version_data(version: int) -> dict:
    """
    Query L1 bio and shades for a given version and return as dict.
    """
    with DatabaseSession.session() as session:
        # Get all data for this version
        bio = session.query(L1Bio).filter(L1Bio.version == version).first()

        shades = session.query(L1Shade).filter(L1Shade.version == version).all()

        clusters = (
            session.query(L1Cluster).filter(L1Cluster.version == version).all()
        )

        chunk_topics = (
            session.query(L1ChunkTopic)
            .filter(L1ChunkTopic.version == version)
            .all()
        )

        if not bio:
            return jsonify(APIResponse.error(f"Version {version} not found"))

        # Build response data
        data = {
            "file_type": "json",
            "content": {
                "version": version,
                "bio": {
                    "content": bio.content,
                    "content_third_view": bio.content_third_view,
                    "summary": bio.summary,
                    "summary_third_view": bio.summary_third_view,
                    "shades": [
                        {
                            "name": s.name,
                            "aspect": s.aspect,
                            "icon": s.icon,
                            "desc_third_view": s.desc_third_view,
                            "content_third_view": s.content_third_view,
                            "desc_second_view": s.desc_second_view,
                            "content_second_view": s.content_second_view,
                        }
                        for s in shades
                    ],
                },
                "clusters": [
                    {
                        "cluster_id": c.cluster_id,
                        "memory_ids": c.memory_ids,
                        "cluster_center": c.cluster_center,
                    }
                    for c in clusters
                ],
                "chunk_topics": [
                    {"chunk_id": t.chunk_id, "topic": t.topic, "tags": t.tags}
                    for t in chunk_topics
                ],
            }
        }
        return data

def read_file_content(file_path: str) -> Optional[Dict]:
    """Read content from a file based on its type."""
    try:
        if file_path.endswith(".json"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = json.load(f)
                return {
                    "file_type": "json",
                    "content": content
                }
            except Exception as e:
                print(f"Error reading JSON file {file_path}: {str(e)}")
                return None
        elif file_path.endswith(".parquet"):
            return read_parquet_file(file_path)
        else:
            print(f"Unsupported file type for {file_path}")
            return None
    except Exception as e:
        print(f"Error reading file {file_path}: {str(e)}")
        return None

def read_parquet_file(file_path: str) -> Optional[Dict]:
    """
    Read a parquet file, convert numpy types for JSON serialization, and return file metadata and content.
    """
    try:
        class NumpyEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, np.integer):
                    return int(obj)
                if isinstance(obj, np.floating):
                    return float(obj)
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                return super(NumpyEncoder, self).default(obj)

        df = pd.read_parquet(file_path)
        # Remove columns named 'x' and 'y' if they exist
        df = df.drop(columns=[col for col in ['x', 'y'] if col in df.columns])
        df_dict = df.to_dict(orient='records')
        json_str = json.dumps(df_dict, cls=NumpyEncoder)
        records = json.loads(json_str)
        return {
            "file_type": "parquet",
            "rows": len(df),
            "columns": list(df.columns),
            "size_bytes": os.path.getsize(file_path),
            "content": records
        }
    except Exception as e:
        print(f"Error reading parquet file {file_path}: {str(e)}")
        return None
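One caveat in this new helper: the missing-version branch returns `jsonify(APIResponse.error(...))`, but neither name is imported in this module, and the caller (`TrainProcessService.get_step_output_content`) expects a plain dict or None rather than a Flask response. A sketch of a Flask-free variant, assuming callers already map None to "no resources for this step":

```python
# Hedged sketch: keep the helper free of Flask so it works wherever
# TrainProcessService calls it. Only names from this commit are used;
# the trimmed payload shape is illustrative.
from typing import Optional

from lpm_kernel.common.repository.database_session import DatabaseSession
from lpm_kernel.models.l1 import L1Bio

def query_l1_version_data_safe(version: int) -> Optional[dict]:
    """Variant of query_l1_version_data that returns plain data."""
    with DatabaseSession.session() as session:
        bio = session.query(L1Bio).filter(L1Bio.version == version).first()
        if not bio:
            # None instead of jsonify(...): the route layer can turn this
            # into an API error, and non-request callers keep working.
            return None
        return {
            "file_type": "json",
            "content": {"version": version, "bio": {"content": bio.content}},
        }
```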
@@ -126,7 +126,6 @@ def start_process():
        logger.error(f"Training process failed: {str(e)}")
        return jsonify(APIResponse.error(message=f"Training process error: {str(e)}"))


@trainprocess_bp.route("/logs", methods=["GET"])
def stream_logs():
    """Get training logs in real-time"""
@@ -149,6 +148,8 @@ def stream_logs():
                    yield f"data: {line.strip()}\n\n"

                last_position = log_file.tell()
+               if not new_lines:
+                   yield f":heartbeat\n\n"
            except Exception as e:
                # If file reading fails, record error and continue
                yield f"data: Error reading log file: {str(e)}\n\n"
@@ -166,7 +167,6 @@ def stream_logs():
        }
    )


@trainprocess_bp.route("/progress/<model_name>", methods=["GET"])
def get_progress(model_name):
    """Get current progress (non-real-time)"""
@@ -244,6 +244,41 @@ def stop_training():
        logger.error(f"Error stopping training process: {str(e)}", exc_info=True)
        return jsonify(APIResponse.error(message=f"Error stopping training process: {str(e)}"))

+@trainprocess_bp.route("/step_output_content", methods=["GET"])
+def get_step_output_content():
+    """
+    Get content of output file for a specific training step
+
+    Request parameters:
+        step_name: Name of the step to get content for, e.g. 'extract_dimensional_topics'
+
+    Returns:
+        Response: JSON response
+        {
+            "code": 0,
+            "message": "Success",
+            "data": {...}  // Content of the output file, or null if not found
+        }
+    """
+    try:
+        # Get TrainProcessService instance
+        train_service = TrainProcessService.get_instance()
+        if train_service is None:
+            logger.error("No active training process found.")
+            return jsonify(APIResponse.error(message="No active training process found."))
+
+        # Get step name from query parameters
+        step_name = request.args.get('step_name')
+        if not step_name:
+            return jsonify(APIResponse.error(message="Missing required parameter: step_name", code=400))
+
+        # Get step output content
+        output_content = train_service.get_step_output_content(step_name)
+        logger.info(f"Step output content: {output_content}")
+        return jsonify(APIResponse.success(data=output_content))
+    except Exception as e:
+        logger.error(f"Failed to get step output content: {str(e)}", exc_info=True)
+        return jsonify(APIResponse.error(message=f"Failed to get step output content: {str(e)}"))

@trainprocess_bp.route("/training_params", methods=["GET"])
def get_training_params():
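The new endpoint can be exercised directly while debugging. A sketch, with a hypothetical BASE_URL and with `requests` handling the URL encoding of step_name:

```python
# Sketch of a direct call to the new step-output endpoint. BASE_URL is an
# assumption; passing step_name via `params` handles URL encoding.
import requests

BASE_URL = "http://localhost:8002"  # hypothetical; adjust to your deployment

def fetch_step_output(step_name: str):
    resp = requests.get(
        f"{BASE_URL}/api/trainprocess/step_output_content",
        params={"step_name": step_name},
    )
    resp.raise_for_status()
    body = resp.json()
    if body.get("code") != 0:
        raise RuntimeError(body.get("message", "unknown error"))
    # {"file_type": "json" | "parquet", "content": ...} or None
    return body.get("data")

print(fetch_step_output("extract_dimensional_topics"))
```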
@@ -19,7 +19,9 @@ class TrainProgress:
                {
                    "name": "Model Download",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                }
            ]
        },
@@ -32,22 +34,30 @@ class TrainProgress:
                {
                    "name": "List Documents",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                },
                {
                    "name": "Generate Document Embeddings",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                },
                {
                    "name": "Process Chunks",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                },
                {
                    "name": "Chunk Embedding",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                }
            ]
        },
@@ -60,17 +70,23 @@ class TrainProgress:
                {
                    "name": "Extract Dimensional Topics",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": True,
+                   "path": "resources/L2/data_pipeline/raw_data/topics.json"
                },
                {
                    "name": "Generate Biography",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": True,
+                   "path": "From database"
                },
                {
                    "name": "Map Your Entity Network",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": True,
+                   "path": "resources/L1/graphrag_indexing_output/subjective/entities.parquet"
                }
            ]
        },
@@ -83,17 +99,23 @@ class TrainProgress:
                {
                    "name": "Decode Preference Patterns",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": True,
+                   "path": "resources/L2/data/preference.json"
                },
                {
                    "name": "Reinforce Identity",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": True,
+                   "path": "resources/L2/data/selfqa.json"
                },
                {
                    "name": "Augment Content Retention",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": True,
+                   "path": "resources/L2/data/diversity.json"
                }
            ]
        },
@@ -106,17 +128,23 @@ class TrainProgress:
                {
                    "name": "Train",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                },
                {
                    "name": "Merge Weights",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                },
                {
                    "name": "Convert Model",
                    "completed": False,
-                   "status": "pending"
+                   "status": "pending",
+                   "have_output": False,
+                   "path": None
                }
            ]
        }
@@ -2,6 +2,7 @@ import os
import re
import time
+import psutil
from typing import Optional, Dict
from lpm_kernel.L1.utils import save_true_topics
from lpm_kernel.L1.serializers import NotesStorage
from lpm_kernel.kernel.note_service import NoteService
@@ -25,8 +26,10 @@ from lpm_kernel.api.domains.trainprocess.progress_enum import Status
from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep
from lpm_kernel.api.domains.trainprocess.progress_holder import TrainProgressHolder
from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager
from lpm_kernel.models.l1 import L1Bio, L1Shade
from lpm_kernel.common.repository.database_session import DatabaseSession
from lpm_kernel.api.domains.kernel.routes import store_l1_data
+from lpm_kernel.api.domains.trainprocess.L1_exposure_manager import output_files, query_l1_version_data, read_file_content
import gc
import subprocess
from lpm_kernel.configs.logging import get_train_process_logger, TRAIN_LOG_FILE
@@ -1135,6 +1138,35 @@ class TrainProcessService:
        except Exception as e:
            logger.error(f"Failed to save progress: {str(e)}", exc_info=True)

+   def get_step_output_content(self, step_name: str = None) -> Optional[Dict]:
+       """Get content of output file for a specific training step
+
+       Args:
+           step_name: Name of the step to get content for. Required parameter.
+
+       Returns:
+           Optional[Dict]: Content of the output file for the specified step, or None if not found
+       """
+       try:
+           if step_name == "generate_biography":
+               logger.info("Querying L1 version data for biography")
+               return query_l1_version_data(1)
+
+           # If step_name is not provided or invalid, return None
+           if not step_name or step_name not in output_files:
+               return None
+
+           # Get file path for the requested step
+           file_path = output_files[step_name]
+           if not os.path.exists(file_path):
+               return None
+
+           # Read and return file content
+           return read_file_content(file_path)
+       except Exception as e:
+           logger.error(f"Error getting step output content: {str(e)}")
+           return None

    def stop_process(self):
        """Stop training process
@@ -0,0 +1,11 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any

class TrainingTags(BaseModel):
    model_name: str
    is_cot: bool = False
    document_count: int = Field(ge=0, default=0)

    class Config:
        extra = "allow"  # Allows additional fields for extensibility
        validate_assignment = True
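A quick usage sketch for the tags model, matching the `model_dump()` serialization the registry client uses later in this commit; the document count here is hypothetical:

```python
# Sketch: constructing and serializing TrainingTags the way the registry
# client does (pydantic v2 model_dump). `extra = "allow"` means unknown
# fields would pass through validation as well.
from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags

tags = TrainingTags(
    model_name="Qwen/Qwen2.5-0.5B-Instruct",  # example model id from this repo
    is_cot=False,
    document_count=42,  # hypothetical count
)
print(tags.model_dump())
# {'model_name': 'Qwen/Qwen2.5-0.5B-Instruct', 'is_cot': False, 'document_count': 42}
```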
@@ -1,5 +1,7 @@
import aiohttp
+import logging
+from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags
from lpm_kernel.configs import config
import websockets
import json
import asyncio
@@ -69,7 +71,7 @@ class RegistryClient:
        """
        return f"{self.ws_url}/api/ws/{instance_id}?password={instance_password}"

-   def register_upload(self, upload_name: str, instance_id: str = None, description: str = None, email: str = None):
+   def register_upload(self, upload_name: str, instance_id: str = None, description: str = None, email: str = None, tags: TrainingTags = None):
        """
        Register Upload instance with the registry center

@@ -83,6 +85,7 @@ class RegistryClient:
            Registration data
        """
        headers = self._get_auth_header()
+       tags_dict = tags.model_dump() if tags else None
        response = requests.post(
            f"{self.server_url}/api/upload/register",
            headers=headers,
@@ -90,7 +93,8 @@ class RegistryClient:
                "upload_name": upload_name,
                "instance_id": instance_id,
                "description": description,
-               "email": email
+               "email": email,
+               "tags": tags_dict
            }
        )
        return ResponseHandler.handle_response(
@@ -11,6 +11,9 @@ from lpm_kernel.api.domains.loads.load_service import LoadService
from .client import RegistryClient
import threading
from lpm_kernel.api.domains.loads.dto import LoadDTO
+from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager
+from lpm_kernel.file_data.document_service import document_service
+from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags

upload_bp = Blueprint("upload", __name__)
registry_client = RegistryClient()
@@ -27,9 +30,18 @@ def register_upload():
    instance_id = current_load.instance_id
    email = current_load.email
    description = current_load.description
+   params = TrainingParamsManager.get_latest_training_params()
+   model_name = params.get("model_name")
+   is_cot = params.get("is_cot")
+   document_count = len(document_service.list_documents())
+   tags = TrainingTags(
+       model_name=model_name,
+       is_cot=is_cot,
+       document_count=document_count
+   )

    result = registry_client.register_upload(
-       upload_name, instance_id, description, email
+       upload_name, instance_id, description, email, tags
    )

    instance_id_new = result.get("instance_id")