133 lines
4.6 KiB
Go
133 lines
4.6 KiB
Go
package task
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
|
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
|
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
|
exectsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
|
|
"gitlink.org.cn/cloudream/scheduler/common/utils"
|
|
"gitlink.org.cn/cloudream/scheduler/executor/internal/config"
|
|
"gitlink.org.cn/cloudream/scheduler/executor/internal/task/create_ecs"
|
|
"io/ioutil"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
type SchedulerDataPreprocess struct {
|
|
*exectsk.SchedulerDataPreprocess
|
|
}
|
|
|
|
func NewSchedulerDataPreprocess(info *exectsk.SchedulerDataPreprocess) *SchedulerDataPreprocess {
|
|
return &SchedulerDataPreprocess{info}
|
|
}
|
|
|
|
func (t *SchedulerDataPreprocess) Execute(task *Task, ctx TaskContext) {
|
|
log := logger.WithType[SchedulerDataPreprocess]("Task")
|
|
log.Debugf("begin")
|
|
defer log.Debugf("end")
|
|
|
|
err := t.do(task, ctx)
|
|
if err != nil {
|
|
log.Error(err)
|
|
return
|
|
}
|
|
|
|
}
|
|
|
|
func (t *SchedulerDataPreprocess) do(task *Task, ctx TaskContext) error {
|
|
// 设置环境变量
|
|
commands := utils.ConvertEnvsToCommand(t.Envs)
|
|
|
|
if t.ObjectStorage.MountType == schsdk.RcloneMount {
|
|
// 获取Rclone挂载命令
|
|
mountCommands := utils.GetRcloneCommands(t.ObjectStorage, t.UserID, schsdk.MountDir)
|
|
commands = append(commands, mountCommands...)
|
|
}
|
|
|
|
// 获取数据预处理命令
|
|
datePreprocessCommands, err := getDataPreprocessCommands(t.Envs, schglb.InferencePlatform)
|
|
if err != nil {
|
|
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus("", err))
|
|
return err
|
|
}
|
|
commands = append(commands, datePreprocessCommands...)
|
|
|
|
// 添加用户自定义命令
|
|
arr := utils.SplitCommands(t.CMD)
|
|
commands = append(commands, arr...)
|
|
|
|
factory := create_ecs.GetFactory(config.CloudName)
|
|
provider := factory.CreateProvider()
|
|
|
|
// 创建服务器
|
|
instanceID, ecsIP, err := provider.CreateServer()
|
|
//instanceID, ecsIP, err := "i-bp16imo8en907iy1oixd", "120.55.45.90", error(nil)
|
|
if err != nil {
|
|
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus("", err))
|
|
return err
|
|
}
|
|
logger.Info("run ECS success, instance id: " + instanceID + ", ip: " + ecsIP)
|
|
|
|
// 执行数据预处理命令
|
|
_, err = provider.RunCommand(commands, instanceID, 2000)
|
|
if err != nil {
|
|
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus("", err))
|
|
return err
|
|
}
|
|
|
|
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus(instanceID, err))
|
|
return nil
|
|
}
|
|
|
|
func getDataPreprocessCommands(envs []schsdk.KVPair, inferencePlatform schsdk.InferencePlatform) ([]string, error) {
|
|
|
|
if inferencePlatform.PlatformName == "" {
|
|
return nil, errors.New("inferencePlatform.PlatformName is empty")
|
|
}
|
|
|
|
var commands []string
|
|
|
|
// 读取预置的脚本
|
|
currentDir, err := filepath.Abs(".")
|
|
if err != nil {
|
|
fmt.Println("Error getting current directory:", err)
|
|
}
|
|
parentDir := filepath.Dir(currentDir)
|
|
fileName := "./scripts/data_preprocess.py"
|
|
filePath := filepath.Join(parentDir, fileName)
|
|
data, err := ioutil.ReadFile(filePath)
|
|
if err != nil {
|
|
fmt.Println("Error reading file:", err)
|
|
}
|
|
fileContent := string(data)
|
|
|
|
fileContent = strings.ReplaceAll(fileContent, "@base_url@", inferencePlatform.ApiBaseUrl)
|
|
fileContent = strings.ReplaceAll(fileContent, "@api_key@", inferencePlatform.ApiKey)
|
|
inputPath := schsdk.MountDir + "/" + envs[0].Value
|
|
inputPath = utils.HandlePath(inputPath)
|
|
fileContent = strings.ReplaceAll(fileContent, "@input_file_path@", inputPath)
|
|
outputPath := schsdk.MountDir + "/" + envs[1].Value
|
|
outputPath = utils.HandlePath(outputPath)
|
|
fileContent = strings.ReplaceAll(fileContent, "@output_file_path@", outputPath)
|
|
fileContent = strings.ReplaceAll(fileContent, "@chunk_max_length@", inferencePlatform.ChunkMaxLength)
|
|
fileContent = strings.ReplaceAll(fileContent, "@start_chunk_threshold@", inferencePlatform.StartChunkThreshold)
|
|
fileContent = strings.ReplaceAll(fileContent, "@similarity_threshold@", inferencePlatform.SimilarityThreshold)
|
|
fileContent = strings.ReplaceAll(fileContent, "@entries_per_file@", inferencePlatform.EntriesPerFile)
|
|
|
|
commandContent := "echo -e '" + fileContent + "' > /opt/generate_data.py"
|
|
commands = append(commands, commandContent)
|
|
|
|
commandContent = "source $HOME/miniconda3/bin/activate myenv && python -m pip install --upgrade pip && pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple && pip install PyMuPDF==1.19.0 scipy openai numpy backoff"
|
|
commands = append(commands, commandContent)
|
|
commandContent = "source $HOME/miniconda3/bin/activate myenv && python /opt/generate_data.py"
|
|
commands = append(commands, commandContent)
|
|
|
|
return commands, nil
|
|
}
|
|
|
|
// init passes the task constructor to Register when the package loads.
// NOTE(review): presumably this lets the task framework build
// SchedulerDataPreprocess tasks from *exectsk.SchedulerDataPreprocess
// messages — confirm against Register.
func init() {
	Register(NewSchedulerDataPreprocess)
}
|