JCC-CSScheduler/executor/internal/task/scheduler_model_finetuning.go

105 lines
3.1 KiB
Go

package task
import (
"errors"
"gitlink.org.cn/cloudream/common/pkgs/logger"
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
exectsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
"gitlink.org.cn/cloudream/scheduler/common/utils"
"gitlink.org.cn/cloudream/scheduler/executor/internal/config"
"gitlink.org.cn/cloudream/scheduler/executor/internal/task/create_ecs"
"strings"
)
// SchedulerModelFinetuning is the executor-side task wrapper around the
// exectsk.SchedulerModelFinetuning message. The message is embedded, so its
// fields (for example Envs, Type, CMD and InferencePlatform, which the methods
// in this file read) are accessible directly on the task value.
type SchedulerModelFinetuning struct {
*exectsk.SchedulerModelFinetuning
}
// NewSchedulerModelFinetuning builds a task that wraps the given message.
// The pointer is stored as-is (no copy is made), so the task and the caller
// share the same underlying message.
func NewSchedulerModelFinetuning(info *exectsk.SchedulerModelFinetuning) *SchedulerModelFinetuning {
	wrapped := &SchedulerModelFinetuning{
		SchedulerModelFinetuning: info,
	}
	return wrapped
}
// Execute runs the model fine-tuning task and logs its outcome.
//
// The actual work is delegated to do. When do returns an error it is logged
// at error level and Execute returns; otherwise a completion message is
// logged at info level. Execute itself does not send any task status.
// "begin" and "end" debug lines bracket the whole call.
func (t *SchedulerModelFinetuning) Execute(task *Task, ctx TaskContext) {
	log := logger.WithType[SchedulerModelFinetuning]("Task")
	log.Debugf("begin")
	defer log.Debugf("end")

	if err := t.do(task, ctx); err != nil {
		log.Error(err)
		return
	}

	// This message previously read "ScheduleCreateECS...", which was copied
	// from a different task and made the log misleading.
	log.Info("SchedulerModelFinetuning...")
}
// do assembles the command list for the fine-tuning job and runs it on an
// ECS instance through the provider selected by config.CloudName.
//
// The command list is the output of utils.ConvertEnvsToCommand(t.Envs)
// followed by t.CMD split with utils.SplitCommands. For tasks of type
// schsdk.DataPreprocess the inference platform settings are validated first.
//
// Every failure is reported through task.SendStatus before the error is
// returned. No status is sent on success.
func (t *SchedulerModelFinetuning) do(task *Task, ctx TaskContext) error {
	// NOTE(review): the target instance and timeout are fixed while server
	// creation/deletion below is disabled. Remove fixedInstanceID once
	// CreateServer is re-enabled. The unit of runTimeout is defined by
	// provider.RunCommand — confirm.
	const (
		fixedInstanceID = "i-bp1ikwdsr5r9p5i9mggm"
		runTimeout      = 2000
	)

	// Set up the environment variables first.
	commands := utils.ConvertEnvsToCommand(t.Envs)

	if t.Type == schsdk.DataPreprocess {
		// NOTE(review): only the validation result is used here; the
		// generated commands are discarded. The generator's script template
		// is still an empty placeholder, so appending its output would
		// overwrite /opt/generate_data.py with an empty file.
		if _, err := getDataPreprocessCommands(t.Envs, t.InferencePlatform); err != nil {
			task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err))
			return err
		}
	}

	commands = append(commands, utils.SplitCommands(t.CMD)...)

	provider := create_ecs.GetFactory(config.CloudName).CreateProvider()

	// Create the server (currently disabled).
	//instanceID, ecsIP, err := provider.CreateServer()
	//if err != nil {
	//	task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err.Error()))
	//	return err
	//}
	//logger.Info("create ECS success, instance id: " + instanceID + ", ip: " + ecsIP)

	// Run the fine-tuning job.
	_, err := provider.RunCommand(commands, fixedInstanceID, runTimeout)

	// Destroy the server once the job has finished (currently disabled).
	//_, err2 := provider.DeleteInstance(instanceID)
	//if err2 != nil {
	//	task.SendStatus(exectsk.NewScheduleCreateECSStatus("", "", err.Error()))
	//}

	if err != nil {
		// Report the failure like every other error path in this function;
		// previously a failed run was returned without any status update.
		task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err))
		return err
	}
	return nil
}
// getDataPreprocessCommands builds the shell command that writes the data
// preprocessing script to /opt/generate_data.py.
//
// The script template has its @base_url@ and @api_key@ placeholders replaced
// with inferencePlatform.ApiBaseUrl and inferencePlatform.ApiKey; the
// @input_file_path@ and @output_file@ placeholders are replaced with the
// empty string. The result is wrapped in a single-quoted `echo -e` command.
//
// It returns an error when inferencePlatform.PlatformName is empty.
// envs is currently unused.
//
// NOTE(review): the template is still an empty placeholder (see the TODO
// below), so the generated command currently writes an empty file.
func getDataPreprocessCommands(envs []schsdk.KVPair, inferencePlatform schsdk.InferencePlatform) ([]string, error) {
	if inferencePlatform.PlatformName == "" {
		return nil, errors.New("inferencePlatform.PlatformName is empty")
	}

	// TODO: read the data_preprocess.py file from the current directory.
	template := ""

	// Substitute all placeholders in a single pass, so a substituted value
	// that happens to contain a placeholder is not expanded again.
	script := strings.NewReplacer(
		"@base_url@", inferencePlatform.ApiBaseUrl,
		"@api_key@", inferencePlatform.ApiKey,
		"@input_file_path@", "",
		"@output_file@", "",
	).Replace(template)

	// The script is embedded in a single-quoted shell string. Close the
	// quote, emit an escaped quote and reopen it for every ' in the content;
	// otherwise a quote in the script, API key or base URL would terminate
	// the string early and the remainder would be run as shell input.
	script = strings.ReplaceAll(script, "'", `'\''`)

	// The command is emitted once; it used to be appended twice verbatim.
	commands := []string{
		"echo -e '" + script + "' > /opt/generate_data.py",
	}
	return commands, nil
}
// init passes the NewSchedulerModelFinetuning constructor to Register when
// the package is initialized.
// NOTE(review): Register is defined elsewhere; presumably it makes this task
// type constructible from an incoming exectsk.SchedulerModelFinetuning
// message — confirm against its definition.
func init() {
Register(NewSchedulerModelFinetuning)
}