新增微调数据预处理功能
This commit is contained in:
parent
d69ea7dc48
commit
d6b47ff5fc
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"application": {
|
||||
"executorID": "2",
|
||||
"executorID": "5",
|
||||
"address": ":7895"
|
||||
},
|
||||
"logger": {
|
||||
|
@ -10,7 +10,7 @@
|
|||
"level": "debug"
|
||||
},
|
||||
"rabbitMQ": {
|
||||
"address": "127.0.0.1:5672",
|
||||
"address": "101.201.215.196:5672",
|
||||
"account": "cloudream",
|
||||
"password": "123456",
|
||||
"vhost": "/"
|
||||
|
@ -22,11 +22,10 @@
|
|||
"url": "http://localhost:7070"
|
||||
},
|
||||
"rclone": {
|
||||
"cds_rcloneID": "9471093",
|
||||
"cds_rcloneConfigID": "9471094"
|
||||
"cds_rcloneID": "9471093"
|
||||
},
|
||||
"reportIntervalSec": 10,
|
||||
"createECS-ali": {
|
||||
"createECS": {
|
||||
"cloud": "AliCloud",
|
||||
"auth_config": {
|
||||
"AccessKeyId": "LTAI5tJBqN3uRnzXeiiXTxkT",
|
||||
|
@ -60,14 +59,13 @@
|
|||
"Region": "cn-hangzhou"
|
||||
}
|
||||
},
|
||||
"createECS": {
|
||||
"createECS-sugon": {
|
||||
"cloud": "SugonCloud",
|
||||
"auth_config": {
|
||||
"user": "acgnnmfbwo",
|
||||
"password": "Pcl@2020",
|
||||
"orgid": "c8befbc1301665ba2dc5b2826f8dca1e",
|
||||
"clusterName": "华东一区【昆山】",
|
||||
"get_token_url": "https://ac.sugon.com/ac/openapi/v2/tokens"
|
||||
"clusterName": "华东一区【昆山】"
|
||||
},
|
||||
"ecs_config": {
|
||||
"description": "",
|
||||
|
@ -100,5 +98,18 @@
|
|||
],
|
||||
"acceleratorDesc": "4*异构加速卡1"
|
||||
}
|
||||
},
|
||||
"inferencePlatform": {
|
||||
"remark": "平台可选值:['xinference', 'ollama', 'oneapi', 'fastchat', 'openai', 'custom openai'],不同平台只需修改platform_name、api_base_url、api_key、api_proxy",
|
||||
"platformName": "xinference",
|
||||
"apiBaseUrl": "http://123.60.146.162:9997/v1",
|
||||
"apiKey": "EMPTY",
|
||||
"apiProxy": "",
|
||||
"llmModel": "glm4-chat01",
|
||||
"embedModel": "bge-large-zh-v1.5",
|
||||
"chunkMaxLength": "4096",
|
||||
"startChunkThreshold": "3000",
|
||||
"similarityThreshold": "0.5",
|
||||
"entriesPerFile": "2"
|
||||
}
|
||||
}
|
|
@ -5,12 +5,18 @@
|
|||
"outputDirectory": "log",
|
||||
"level": "debug"
|
||||
},
|
||||
"rabbitMQ": {
|
||||
"rabbitMQ2": {
|
||||
"address": "101.201.215.196:5672",
|
||||
"account": "cloudream",
|
||||
"password": "123456",
|
||||
"vhost": "/"
|
||||
},
|
||||
"rabbitMQ": {
|
||||
"address": "localhost:5672",
|
||||
"account": "cloudream",
|
||||
"password": "123456",
|
||||
"vhost": "/"
|
||||
},
|
||||
"db": {
|
||||
"address": "101.201.215.196:3306",
|
||||
"account": "pcm",
|
||||
|
@ -18,7 +24,10 @@
|
|||
"databaseName": "scheduler"
|
||||
},
|
||||
"cloudreamStorage": {
|
||||
"url": "http://120.46.183.86:7890"
|
||||
"url": "http://121.36.5.116:7890"
|
||||
},
|
||||
"reportTimeoutSecs": 20
|
||||
"reportTimeoutSecs": 10,
|
||||
"CDSRclone": {
|
||||
"cds_rcloneID": "1"
|
||||
}
|
||||
}
|
|
@ -39,8 +39,9 @@ type JobFiles struct {
|
|||
}
|
||||
|
||||
type PackageJobFile struct {
|
||||
PackageID cdssdk.PackageID `json:"packageID"`
|
||||
PackagePath string `json:"packagePath"` // Load之后的文件路径,一个相对路径,需要加上CDS数据库中的RemoteBase才是完整路径
|
||||
PackageID cdssdk.PackageID `json:"packageID"`
|
||||
PackagePath string `json:"packagePath"` // Load之后的文件路径,一个相对路径,需要加上CDS数据库中的RemoteBase才是完整路径
|
||||
ECSInstanceID schsdk.ECSInstanceID
|
||||
}
|
||||
|
||||
type ImageJobFile struct {
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
package task
|
||||
|
||||
import (
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
|
||||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||||
)
|
||||
|
||||
type SchedulerDataPreprocess struct {
|
||||
TaskInfoBase
|
||||
CMD string `json:"cmd"`
|
||||
Envs []schsdk.KVPair `json:"envs"`
|
||||
UserID cdssdk.UserID `json:"userID"`
|
||||
ObjectStorage schmod.ObjectStorage `json:"objectStorage"`
|
||||
//ModelResource schmod.ModelResource `json:"modelResource"`
|
||||
}
|
||||
|
||||
type SchedulerDataPreprocessStatus struct {
|
||||
TaskStatusBase
|
||||
InstanceID string `json:"instanceID"`
|
||||
Error error `json:"error"`
|
||||
}
|
||||
|
||||
func NewSchedulerDataPreprocess(userID cdssdk.UserID, command string, envs []schsdk.KVPair, objectStorage schmod.ObjectStorage) *SchedulerDataPreprocess {
|
||||
return &SchedulerDataPreprocess{
|
||||
CMD: command,
|
||||
Envs: envs,
|
||||
UserID: userID,
|
||||
ObjectStorage: objectStorage,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSchedulerDataPreprocessStatus(instanceID string, err error) *SchedulerDataPreprocessStatus {
|
||||
return &SchedulerDataPreprocessStatus{
|
||||
InstanceID: instanceID,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register[*SchedulerDataPreprocess, *SchedulerDataPreprocessStatus]()
|
||||
}
|
|
@ -1,13 +1,19 @@
|
|||
package task
|
||||
|
||||
import schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
import (
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
|
||||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||||
)
|
||||
|
||||
type SchedulerModelFinetuning struct {
|
||||
TaskInfoBase
|
||||
Type string `json:"type"`
|
||||
CMD string `json:"cmd"`
|
||||
Envs []schsdk.KVPair `json:"envs"`
|
||||
InferencePlatform schsdk.InferencePlatform `json:"inferencePlatform"`
|
||||
InstanceID string `json:"instanceID"`
|
||||
CMD string `json:"cmd"`
|
||||
Envs []schsdk.KVPair `json:"envs"`
|
||||
UserID cdssdk.UserID `json:"userID"`
|
||||
ObjectStorage schmod.ObjectStorage `json:"objectStorage"`
|
||||
ModelResource schmod.ModelResource `json:"modelResource"`
|
||||
}
|
||||
|
||||
type SchedulerModelFinetuningStatus struct {
|
||||
|
@ -15,11 +21,14 @@ type SchedulerModelFinetuningStatus struct {
|
|||
Error error `json:"error"`
|
||||
}
|
||||
|
||||
func NewSchedulerModelFinetuning(cmd string, envs []schsdk.KVPair, inferencePlatform schsdk.InferencePlatform) *SchedulerModelFinetuning {
|
||||
func NewSchedulerModelFinetuning(userID cdssdk.UserID, command string, objectStorage schmod.ObjectStorage, modelResource schmod.ModelResource, envs []schsdk.KVPair, instanceID string) *SchedulerModelFinetuning {
|
||||
return &SchedulerModelFinetuning{
|
||||
CMD: cmd,
|
||||
Envs: envs,
|
||||
InferencePlatform: inferencePlatform,
|
||||
CMD: command,
|
||||
Envs: envs,
|
||||
UserID: userID,
|
||||
ObjectStorage: objectStorage,
|
||||
ModelResource: modelResource,
|
||||
InstanceID: instanceID,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -182,7 +182,26 @@ func (s *DefaultPreScheduler) ScheduleJobSet(info *schsdk.JobSetInfo) (*jobmod.J
|
|||
|
||||
// 经过排序后,按顺序生成调度方案
|
||||
for _, job := range schJobs {
|
||||
if norJob, ok := job.Job.(*schsdk.NormalJobInfo); ok {
|
||||
|
||||
var fileInfo schsdk.JobFilesInfo
|
||||
isNormalType := false
|
||||
norJob, ok := job.Job.(*schsdk.NormalJobInfo)
|
||||
if ok {
|
||||
fileInfo = norJob.Files
|
||||
isNormalType = true
|
||||
}
|
||||
dpJob, ok := job.Job.(*schsdk.DataPreprocessJobInfo)
|
||||
if ok {
|
||||
fileInfo = dpJob.Files
|
||||
isNormalType = true
|
||||
}
|
||||
ftJob, ok := job.Job.(*schsdk.FinetuningJobInfo)
|
||||
if ok {
|
||||
fileInfo = ftJob.Files
|
||||
isNormalType = true
|
||||
}
|
||||
|
||||
if isNormalType {
|
||||
scheme, err := s.scheduleForNormalOrMultiJob(info, job, ccs, jobSetScheme.JobSchemes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -191,7 +210,7 @@ func (s *DefaultPreScheduler) ScheduleJobSet(info *schsdk.JobSetInfo) (*jobmod.J
|
|||
jobSetScheme.JobSchemes[job.Job.GetLocalJobID()] = *scheme
|
||||
|
||||
// 检查数据文件的配置项,生成上传文件方案
|
||||
s.fillNormarlJobLocalUploadScheme(norJob.Files, scheme.TargetCCID, filesUploadSchemes, ccs)
|
||||
s.fillNormarlJobLocalUploadScheme(fileInfo, scheme.TargetCCID, filesUploadSchemes, ccs)
|
||||
}
|
||||
|
||||
if mulJob, ok := job.Job.(*schsdk.MultiInstanceJobInfo); ok {
|
||||
|
@ -353,6 +372,12 @@ func (s *DefaultPreScheduler) scheduleForNormalOrMultiJob(jobSet *schsdk.JobSetI
|
|||
case *schsdk.NormalJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
case *schsdk.DataPreprocessJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
case *schsdk.FinetuningJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
case *schsdk.MultiInstanceJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
|
@ -408,6 +433,12 @@ func (s *DefaultPreScheduler) scheduleForSingleJob(job *schedulingJob, ccs map[s
|
|||
case *schsdk.NormalJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
case *schsdk.DataPreprocessJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
case *schsdk.FinetuningJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
case *schsdk.MultiInstanceJobInfo:
|
||||
jobFiles = &runningJob.Files
|
||||
jobResource = &runningJob.Resources
|
||||
|
|
|
@ -4,9 +4,8 @@ import (
|
|||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -60,6 +59,7 @@ func ConvertEnvsToCommand(envs []schsdk.KVPair) []string {
|
|||
commandContent := "sed -i '/@key@/d' ~/.bashrc && echo 'export @key@=@value@' >> ~/.bashrc"
|
||||
commandContent = strings.Replace(commandContent, "@key@", envs[i].Key, -1)
|
||||
commandContent = strings.Replace(commandContent, "@value@", value, -1)
|
||||
logger.Info("env: " + commandContent)
|
||||
commands = append(commands, commandContent)
|
||||
}
|
||||
commandContent := "sudo source ~/.bashrc"
|
||||
|
@ -68,84 +68,67 @@ func ConvertEnvsToCommand(envs []schsdk.KVPair) []string {
|
|||
return commands
|
||||
}
|
||||
|
||||
func GetSSHClient(username string, password string, address string) *ssh.Client {
|
||||
// SSH连接配置
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: []ssh.AuthMethod{
|
||||
// 使用密码认证
|
||||
ssh.Password(password),
|
||||
// 或者使用私钥认证
|
||||
//publicKeyFile("C:\\Users\\27081\\.ssh\\id_rsa"),
|
||||
},
|
||||
// 安全性考虑,跳过主机密钥检查
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
func GetRcloneCommands(storage schmod.ObjectStorage, userID cdssdk.UserID, mountDir string) []string {
|
||||
var commands []string
|
||||
|
||||
// 连接SSH服务器
|
||||
client, err := ssh.Dial("tcp", address, sshConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to dial: %s", err)
|
||||
}
|
||||
//defer client.Close()
|
||||
return client
|
||||
// 下载Rclone
|
||||
//commandContent := "yum install -y fuse3"
|
||||
//commands = append(commands, commandContent)
|
||||
//commandContent = "cd /opt && downloadCode='import requests;response=requests.get(\"@url@\",stream=True);response.raise_for_status();boundary=response.headers.get(\"Content-Type\").split(\"boundary=\")[-1].encode();content=response.content;body=[part.split(b\"\\r\\n\\r\\n\",1)[1].rsplit(b\"\\r\\n--\",1)[0] for part in content.split(b\"--\"+boundary+b\"\\r\\n\") if b\"filename=\" in part][0];open(\"@filename@\",\"wb\").write(body);print(\"success\")' && rclone=\"$cds_url/object/download?userID=$userID&objectID=$rcloneID\" && python3 -c \"$(echo \"$downloadCode\" | sed -e \"s|@url@|$(printf '%s' \"$rclone\" | sed 's/[&/\\]/\\\\&/g')|\" -e \"s|@filename@|rclone|\")\" && chmod +x rclone"
|
||||
//commandContent = strings.Replace(commandContent, "$cds_url", schglb.CloudreamStorageConfig.URL, -1)
|
||||
//commandContent = strings.Replace(commandContent, "$rcloneID", schglb.CDSRclone.CDSRcloneID, -1)
|
||||
//commandContent = strings.Replace(commandContent, "$userID", strconv.FormatInt(int64(userID), 10), -1)
|
||||
//commands = append(commands, commandContent)
|
||||
//
|
||||
//// 生成Rclone配置文件
|
||||
//commandContent = "echo -e '[@tagName@] \n type = s3 \n provider = @provider@ \n access_key_id = @ak@ \n secret_access_key = @sk@ \n endpoint = @endpoint@ \n storage_class = STANDARD' > /opt/rclone.conf"
|
||||
tagName := storage.Bucket + "_" + storage.AK
|
||||
//commandContent = strings.Replace(commandContent, "@tagName@", tagName, -1)
|
||||
//commandContent = strings.Replace(commandContent, "@provider@", storage.Manufacturer, -1)
|
||||
//commandContent = strings.Replace(commandContent, "@ak@", storage.AK, -1)
|
||||
//commandContent = strings.Replace(commandContent, "@sk@", storage.SK, -1)
|
||||
//commandContent = strings.Replace(commandContent, "@endpoint@", storage.Endpoint, -1)
|
||||
//commands = append(commands, commandContent)
|
||||
|
||||
umountCommand := "umount -l /mnt/oss"
|
||||
commands = append(commands, umountCommand)
|
||||
// 挂载Rclone
|
||||
commandContent := "mkdir -p @mountDir@ && cd /opt && nohup ./rclone mount @tagName@:@bucket@ @mountDir@ --vfs-cache-mode full --vfs-read-wait 0 --vfs-read-chunk-size 128M --cache-db-purge -vv > rcloneMount.log 2>&1 &"
|
||||
commandContent = strings.Replace(commandContent, "@tagName@", tagName, -1)
|
||||
commandContent = strings.Replace(commandContent, "@bucket@", storage.Bucket, -1)
|
||||
commandContent = strings.Replace(commandContent, "@mountDir@", mountDir, -1)
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
//commandContent = "cd /opt && wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda3 && eval \"$($HOME/miniconda3/bin/conda shell.bash hook)\" && conda create -n myenv python=3.10 -y"
|
||||
//commands = append(commands, commandContent)
|
||||
|
||||
return commands
|
||||
}
|
||||
|
||||
// 如果使用私钥认证,可以使用这个函数加载私钥文件
|
||||
func publicKeyFile(file string) ssh.AuthMethod {
|
||||
buffer, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key, err := ssh.ParsePrivateKey(buffer)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return ssh.PublicKeys(key)
|
||||
func RemountRclone(storage schmod.ObjectStorage, userID cdssdk.UserID, mountDir string) string {
|
||||
umountCommand := "umount -l /mnt/oss"
|
||||
// 挂载Rclone
|
||||
commandContent := "mkdir -p @mountDir@ && cd /opt && nohup ./rclone mount @tagName@:@bucket@ @mountDir@ --vfs-cache-mode full --vfs-read-wait 0 --vfs-read-chunk-size 128M --cache-db-purge -vv > rcloneMount.log 2>&1 &"
|
||||
tagName := storage.Bucket + "_" + storage.AK
|
||||
commandContent = strings.Replace(commandContent, "@tagName@", tagName, -1)
|
||||
commandContent = strings.Replace(commandContent, "@bucket@", storage.Bucket, -1)
|
||||
commandContent = strings.Replace(commandContent, "@mountDir@", mountDir, -1)
|
||||
return umountCommand + " \n " + commandContent
|
||||
}
|
||||
|
||||
// 将嵌套的 map 处理成查询字符串的递归函数
|
||||
func ParseMapToStrings(config map[string]interface{}, prefix string) []string {
|
||||
var queryStrings []string
|
||||
for key, value := range config {
|
||||
fullKey := key
|
||||
if prefix != "" {
|
||||
fullKey = prefix + "." + key
|
||||
}
|
||||
func HandleCommand(startScript string) string {
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
queryStrings = append(queryStrings, fmt.Sprintf("%s=%s", fullKey, v))
|
||||
case int, float64:
|
||||
queryStrings = append(queryStrings, fmt.Sprintf("%s=%v", fullKey, v))
|
||||
case bool:
|
||||
queryStrings = append(queryStrings, fmt.Sprintf("%s=%t", fullKey, v))
|
||||
case map[string]interface{}:
|
||||
// 递归处理嵌套的 map
|
||||
queryStrings = append(queryStrings, ParseMapToStrings(v, fullKey)...)
|
||||
case []interface{}:
|
||||
// 处理数组
|
||||
for i, item := range v {
|
||||
// 数组的键以索引为后缀
|
||||
itemKey := fmt.Sprintf("%s[%d]", fullKey, i)
|
||||
switch item := item.(type) {
|
||||
case string:
|
||||
queryStrings = append(queryStrings, fmt.Sprintf("%s=%s", itemKey, item))
|
||||
case int, float64:
|
||||
queryStrings = append(queryStrings, fmt.Sprintf("%s=%v", itemKey, item))
|
||||
case bool:
|
||||
queryStrings = append(queryStrings, fmt.Sprintf("%s=%t", itemKey, item))
|
||||
case map[string]interface{}:
|
||||
// 递归处理嵌套的 map
|
||||
queryStrings = append(queryStrings, ParseMapToStrings(item, itemKey)...)
|
||||
default:
|
||||
fmt.Println("Unsupported config array item type:", item)
|
||||
}
|
||||
}
|
||||
default:
|
||||
fmt.Println("Unsupported config value type:", v)
|
||||
}
|
||||
}
|
||||
return queryStrings
|
||||
startScript = strings.Replace(startScript, "//", "/", -1)
|
||||
commandContent := "sudo sh @startScript@ > /opt/@startLog@.log"
|
||||
commandContent = strings.Replace(commandContent, "@startScript@", startScript, -1)
|
||||
arr := strings.Split(startScript, "/")
|
||||
commandContent = strings.Replace(commandContent, "@startLog@", arr[len(arr)-1], -1)
|
||||
|
||||
return commandContent
|
||||
}
|
||||
|
||||
func HandlePath(path string) string {
|
||||
path = strings.ReplaceAll(path, "\\", "/")
|
||||
path = strings.ReplaceAll(path, "//", "/")
|
||||
return path
|
||||
}
|
||||
|
|
|
@ -3,18 +3,13 @@ package task
|
|||
import (
|
||||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
|
||||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||||
"gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor"
|
||||
exectsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
|
||||
"gitlink.org.cn/cloudream/scheduler/common/utils"
|
||||
"gitlink.org.cn/cloudream/scheduler/executor/internal/config"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
//"gitlink.org.cn/cloudream/scheduler/executor/internal/config"
|
||||
"gitlink.org.cn/cloudream/scheduler/executor/internal/task/create_ecs"
|
||||
)
|
||||
|
@ -47,8 +42,8 @@ func (t *ScheduleCreateECS) do(task *Task, ctx TaskContext) error {
|
|||
// 创建云主机
|
||||
factory := create_ecs.GetFactory(config.CloudName)
|
||||
provider := factory.CreateProvider()
|
||||
instanceID, ecsIP, err := provider.CreateServer()
|
||||
//instanceID, ecsIP, err := "i-bp18see6gypratlt3nhp", "47.96.28.209", error(nil)
|
||||
//instanceID, ecsIP, err := provider.CreateServer()
|
||||
instanceID, ecsIP, err := "i-bp16imo8en907iy1oixd", "120.55.45.90", error(nil)
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewScheduleCreateECSStatus("", schsdk.CreateECS, err.Error()))
|
||||
return err
|
||||
|
@ -65,27 +60,28 @@ func (t *ScheduleCreateECS) do(task *Task, ctx TaskContext) error {
|
|||
}
|
||||
|
||||
// 设置环境变量
|
||||
t.Envs = append(t.Envs, schsdk.KVPair{Key: "MountDir", Value: schsdk.MountDir})
|
||||
commands := utils.ConvertEnvsToCommand(t.Envs)
|
||||
|
||||
// 获取挂载命令
|
||||
switch t.ObjectStorage.MountType {
|
||||
case schsdk.RcloneMount:
|
||||
rcloneCommands := getRcloneCommands(t.ModelResource, t.ObjectStorage, t.UserID)
|
||||
commands = append(commands, rcloneCommands...)
|
||||
case schsdk.Mounted:
|
||||
commandContent := "sudo sh @startScript@ > /opt/startup.log"
|
||||
commandContent = strings.Replace(commandContent, "@startScript@", t.ModelResource.StartShellPath, -1)
|
||||
commands = append(commands, commandContent)
|
||||
startScript := t.ModelResource.StartShellPath
|
||||
if t.ObjectStorage.MountType == schsdk.RcloneMount {
|
||||
startScript = schsdk.MountDir + "/" + t.ModelResource.StartShellPath
|
||||
// 获取Rclone挂载命令
|
||||
mountCommands := utils.GetRcloneCommands(t.ObjectStorage, t.UserID, schsdk.MountDir)
|
||||
commands = append(commands, mountCommands...)
|
||||
}
|
||||
// 获取启动命令
|
||||
commands = append(commands, utils.HandleCommand(startScript))
|
||||
|
||||
// 安装依赖包,用于获取GPU信息
|
||||
commandContent := getPipCommand()
|
||||
commands = append(commands, commandContent)
|
||||
//commandContent := getPipCommand()
|
||||
//commands = append(commands, commandContent)
|
||||
|
||||
// 获取用户输入的命令
|
||||
arr := utils.SplitCommands(t.Command)
|
||||
commands = append(commands, arr...)
|
||||
|
||||
// 执行命令
|
||||
//_, err = provider.RunCommand(commands, instanceID, 2000)
|
||||
//if err != nil {
|
||||
// logger.Error("run command error: " + err.Error())
|
||||
|
@ -131,13 +127,22 @@ func (t *ScheduleCreateECS) do(task *Task, ctx TaskContext) error {
|
|||
}
|
||||
task.SendStatus(exectsk.NewScheduleCreateECSStatus("", schsdk.PauseECS, ""))
|
||||
case schsdk.DestroyECS:
|
||||
_, err := provider.DeleteInstance(instanceID)
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewScheduleCreateECSStatus("", "", err.Error()))
|
||||
continue
|
||||
}
|
||||
logger.Info("destroy ecs")
|
||||
//_, err := provider.DeleteInstance(instanceID)
|
||||
//if err != nil {
|
||||
// task.SendStatus(exectsk.NewScheduleCreateECSStatus("", "", err.Error()))
|
||||
// continue
|
||||
//}
|
||||
task.SendStatus(exectsk.NewScheduleCreateECSStatus("", schsdk.DestroyECS, ""))
|
||||
break
|
||||
case schsdk.RestartServer:
|
||||
commandContent := utils.RemountRclone(t.ObjectStorage, t.UserID, schsdk.MountDir)
|
||||
info.Runtime.Command = info.Runtime.Command + "\n" + commandContent
|
||||
commandContent = schsdk.MountDir + "/" + t.ModelResource.StopShellPath
|
||||
info.Runtime.Command = info.Runtime.Command + "\n" + utils.HandleCommand(commandContent)
|
||||
commandContent = schsdk.MountDir + "/" + t.ModelResource.StartShellPath
|
||||
info.Runtime.Command = info.Runtime.Command + "\n" + utils.HandleCommand(commandContent)
|
||||
executeCommands(provider, instanceID, task, info.Runtime)
|
||||
case schsdk.OperateServer:
|
||||
executeCommands(provider, instanceID, task, info.Runtime)
|
||||
case schsdk.GPUMonitor:
|
||||
|
@ -168,12 +173,6 @@ func (t *ScheduleCreateECS) do(task *Task, ctx TaskContext) error {
|
|||
}
|
||||
}
|
||||
|
||||
func getRandomNum() string {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
randomFloat := rand.Float64() * 20
|
||||
return strconv.FormatFloat(randomFloat, 'f', 2, 64)
|
||||
}
|
||||
|
||||
func getPipCommand() string {
|
||||
commandContent := "python -m pip install --upgrade pip \n pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple \n pip install torch \n pip install gputil \n pip install psutil"
|
||||
return commandContent
|
||||
|
@ -195,46 +194,6 @@ func getGPUCommand(instanceID string) ([]string, string) {
|
|||
return commands, logFileName
|
||||
}
|
||||
|
||||
func getRcloneCommands(resource schmod.ModelResource, storage schmod.ObjectStorage, userID cdssdk.UserID) []string {
|
||||
var commands []string
|
||||
|
||||
// 下载Rclone
|
||||
commandContent := "yum install -y fuse3"
|
||||
commands = append(commands, commandContent)
|
||||
commandContent = "cd /opt && downloadCode='import requests;response=requests.get(\"@url@\",stream=True);response.raise_for_status();boundary=response.headers.get(\"Content-Type\").split(\"boundary=\")[-1].encode();content=response.content;body=[part.split(b\"\\r\\n\\r\\n\",1)[1].rsplit(b\"\\r\\n--\",1)[0] for part in content.split(b\"--\"+boundary+b\"\\r\\n\") if b\"filename=\" in part][0];open(\"@filename@\",\"wb\").write(body);print(\"success\")' && rclone=\"$cds_url/object/download?userID=$userID&objectID=$rcloneID\" && python3 -c \"$(echo \"$downloadCode\" | sed -e \"s|@url@|$(printf '%s' \"$rclone\" | sed 's/[&/\\]/\\\\&/g')|\" -e \"s|@filename@|rclone|\")\" && chmod +x rclone"
|
||||
commandContent = strings.Replace(commandContent, "$cds_url", schglb.CloudreamStorageConfig.URL, -1)
|
||||
commandContent = strings.Replace(commandContent, "$rcloneID", schglb.CDSRclone.CDSRcloneID, -1)
|
||||
commandContent = strings.Replace(commandContent, "$userID", strconv.FormatInt(int64(userID), 10), -1)
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
// 生成Rclone配置文件
|
||||
commandContent = "echo -e '[@tagName@] \n type = s3 \n provider = @provider@ \n access_key_id = @ak@ \n secret_access_key = @sk@ \n endpoint = @endpoint@ \n storage_class = STANDARD' > /opt/rclone.conf"
|
||||
tagName := storage.Bucket + "_" + storage.AK
|
||||
commandContent = strings.Replace(commandContent, "@tagName@", tagName, -1)
|
||||
commandContent = strings.Replace(commandContent, "@provider@", storage.Manufacturer, -1)
|
||||
commandContent = strings.Replace(commandContent, "@ak@", storage.AK, -1)
|
||||
commandContent = strings.Replace(commandContent, "@sk@", storage.SK, -1)
|
||||
commandContent = strings.Replace(commandContent, "@endpoint@", storage.Endpoint, -1)
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
// 挂载Rclone
|
||||
mountDir := "/mnt/oss"
|
||||
commandContent = "mkdir -p @mountDir@ && cd /opt && nohup ./rclone mount @tagName@:@bucket@ @mountDir@ --vfs-cache-mode full --vfs-read-wait 0 --vfs-read-chunk-size 128M --cache-db-purge -vv > rcloneMount.log 2>&1 &"
|
||||
commandContent = strings.Replace(commandContent, "@tagName@", tagName, -1)
|
||||
commandContent = strings.Replace(commandContent, "@bucket@", storage.Bucket, -1)
|
||||
commandContent = strings.Replace(commandContent, "@mountDir@", mountDir, -1)
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
// 执行启动脚本
|
||||
startScript := mountDir + "/" + resource.StartShellPath
|
||||
startScript = strings.Replace(startScript, "//", "/", -1)
|
||||
commandContent = "sudo sh @startScript@ > /opt/startup.log"
|
||||
commandContent = strings.Replace(commandContent, "@startScript@", startScript, -1)
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
return commands
|
||||
}
|
||||
|
||||
func executeCommands(provider create_ecs.CloudProvider, instanceID string, task *Task, runtime schsdk.JobRuntimeInfo) {
|
||||
commands := utils.ConvertEnvsToCommand(runtime.Envs)
|
||||
commands = append(commands, utils.SplitCommands(runtime.Command)...)
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
package task
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||||
exectsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
|
||||
"gitlink.org.cn/cloudream/scheduler/common/utils"
|
||||
"gitlink.org.cn/cloudream/scheduler/executor/internal/config"
|
||||
"gitlink.org.cn/cloudream/scheduler/executor/internal/task/create_ecs"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SchedulerDataPreprocess struct {
|
||||
*exectsk.SchedulerDataPreprocess
|
||||
}
|
||||
|
||||
func NewSchedulerDataPreprocess(info *exectsk.SchedulerDataPreprocess) *SchedulerDataPreprocess {
|
||||
return &SchedulerDataPreprocess{info}
|
||||
}
|
||||
|
||||
func (t *SchedulerDataPreprocess) Execute(task *Task, ctx TaskContext) {
|
||||
log := logger.WithType[SchedulerDataPreprocess]("Task")
|
||||
log.Debugf("begin")
|
||||
defer log.Debugf("end")
|
||||
|
||||
err := t.do(task, ctx)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (t *SchedulerDataPreprocess) do(task *Task, ctx TaskContext) error {
|
||||
// 设置环境变量
|
||||
commands := utils.ConvertEnvsToCommand(t.Envs)
|
||||
|
||||
if t.ObjectStorage.MountType == schsdk.RcloneMount {
|
||||
// 获取Rclone挂载命令
|
||||
mountCommands := utils.GetRcloneCommands(t.ObjectStorage, t.UserID, schsdk.MountDir)
|
||||
commands = append(commands, mountCommands...)
|
||||
}
|
||||
|
||||
// 获取数据预处理命令
|
||||
datePreprocessCommands, err := getDataPreprocessCommands(t.Envs, schglb.InferencePlatform)
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus("", err))
|
||||
return err
|
||||
}
|
||||
commands = append(commands, datePreprocessCommands...)
|
||||
|
||||
// 添加用户自定义命令
|
||||
arr := utils.SplitCommands(t.CMD)
|
||||
commands = append(commands, arr...)
|
||||
|
||||
factory := create_ecs.GetFactory(config.CloudName)
|
||||
provider := factory.CreateProvider()
|
||||
|
||||
// 创建服务器
|
||||
//instanceID, ecsIP, err := provider.CreateServer()
|
||||
instanceID, ecsIP, err := "i-bp16imo8en907iy1oixd", "120.55.45.90", error(nil)
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus("", err))
|
||||
return err
|
||||
}
|
||||
logger.Info("run ECS success, instance id: " + instanceID + ", ip: " + ecsIP)
|
||||
|
||||
// 执行数据预处理命令
|
||||
_, err = provider.RunCommand(commands, instanceID, 2000)
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus("", err))
|
||||
return err
|
||||
}
|
||||
|
||||
task.SendStatus(exectsk.NewSchedulerDataPreprocessStatus(instanceID, err))
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDataPreprocessCommands(envs []schsdk.KVPair, inferencePlatform schsdk.InferencePlatform) ([]string, error) {
|
||||
|
||||
if inferencePlatform.PlatformName == "" {
|
||||
return nil, errors.New("inferencePlatform.PlatformName is empty")
|
||||
}
|
||||
|
||||
var commands []string
|
||||
|
||||
// 获取当前工作目录
|
||||
currentDir, err := filepath.Abs(".")
|
||||
if err != nil {
|
||||
fmt.Println("Error getting current directory:", err)
|
||||
}
|
||||
|
||||
parentDir := filepath.Dir(currentDir)
|
||||
|
||||
// 指定要读取的文件名
|
||||
fileName := "example.txt" // 替换为你要读取的文件名
|
||||
|
||||
// 构造完整路径
|
||||
filePath := filepath.Join(parentDir, fileName)
|
||||
|
||||
// 读取文件
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
fmt.Println("Error reading file:", err)
|
||||
}
|
||||
|
||||
// 输出文件内容
|
||||
fmt.Println("File content:")
|
||||
fmt.Println(string(data))
|
||||
content := string(data)
|
||||
|
||||
// 读取文件
|
||||
//content, err := ioutil.ReadFile("D:\\Work\\Codes\\new\\workspace\\workspace\\scheduler\\common\\assets\\scripts\\data_preprocess.py")
|
||||
//if err != nil {
|
||||
// logger.Error(err)
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
fileContent := string(content)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@base_url@", inferencePlatform.ApiBaseUrl)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@api_key@", inferencePlatform.ApiKey)
|
||||
inputPath := schsdk.MountDir + "/" + envs[0].Value
|
||||
inputPath = utils.HandlePath(inputPath)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@input_file_path@", inputPath)
|
||||
outputPath := schsdk.MountDir + "/" + envs[1].Value
|
||||
outputPath = utils.HandlePath(outputPath)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@output_file_path@", outputPath)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@chunk_max_length@", inferencePlatform.ChunkMaxLength)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@start_chunk_threshold@", inferencePlatform.StartChunkThreshold)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@similarity_threshold@", inferencePlatform.SimilarityThreshold)
|
||||
fileContent = strings.ReplaceAll(fileContent, "@entries_per_file@", inferencePlatform.EntriesPerFile)
|
||||
|
||||
commandContent := "echo -e '" + fileContent + "' > /opt/generate_data.py"
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
commandContent = "source $HOME/miniconda3/bin/activate myenv && python -m pip install --upgrade pip && pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple && pip install PyMuPDF==1.19.0 scipy openai numpy backoff"
|
||||
commands = append(commands, commandContent)
|
||||
commandContent = "source $HOME/miniconda3/bin/activate myenv && python /opt/generate_data.py"
|
||||
commands = append(commands, commandContent)
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register(NewSchedulerDataPreprocess)
|
||||
}
|
|
@ -1,14 +1,12 @@
|
|||
package task
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
exectsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
|
||||
"gitlink.org.cn/cloudream/scheduler/common/utils"
|
||||
"gitlink.org.cn/cloudream/scheduler/executor/internal/config"
|
||||
"gitlink.org.cn/cloudream/scheduler/executor/internal/task/create_ecs"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SchedulerModelFinetuning struct {
|
||||
|
@ -34,71 +32,64 @@ func (t *SchedulerModelFinetuning) Execute(task *Task, ctx TaskContext) {
|
|||
}
|
||||
|
||||
func (t *SchedulerModelFinetuning) do(task *Task, ctx TaskContext) error {
|
||||
|
||||
// t.Envs添加新值
|
||||
t.Envs = append(t.Envs, schsdk.KVPair{Key: "MountDir", Value: schsdk.MountDir})
|
||||
// 设置环境变量
|
||||
commands := utils.ConvertEnvsToCommand(t.Envs)
|
||||
|
||||
if t.Type == schsdk.DataPreprocess {
|
||||
_, err := getDataPreprocessCommands(t.Envs, t.InferencePlatform)
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
arr := utils.SplitCommands(t.CMD)
|
||||
commands = append(commands, arr...)
|
||||
|
||||
factory := create_ecs.GetFactory(config.CloudName)
|
||||
provider := factory.CreateProvider()
|
||||
|
||||
// 创建服务器
|
||||
//instanceID, ecsIP, err := provider.CreateServer()
|
||||
//if err != nil {
|
||||
// task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err.Error()))
|
||||
// return err
|
||||
instanceID := t.InstanceID
|
||||
// 如果没有指定实例ID,则创建一个
|
||||
//if t.InstanceID == "" {
|
||||
// // 创建服务器
|
||||
// instID, ecsIP, err := provider.CreateServer()
|
||||
// if err != nil {
|
||||
// task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err))
|
||||
// return err
|
||||
// }
|
||||
// instanceID = instID
|
||||
// logger.Info("create ECS success, instance id: " + instanceID + ", ip: " + ecsIP)
|
||||
//
|
||||
// if t.ObjectStorage.MountType == schsdk.RcloneMount {
|
||||
// // 获取Rclone挂载命令
|
||||
// mountCommands := utils.GetRcloneCommands(t.ObjectStorage, t.UserID, schsdk.MountDir)
|
||||
// commands = append(commands, mountCommands...)
|
||||
// }
|
||||
//}
|
||||
//logger.Info("create ECS success, instance id: " + instanceID + ", ip: " + ecsIP)
|
||||
|
||||
mountCommands := utils.GetRcloneCommands(t.ObjectStorage, t.UserID, schsdk.MountDir)
|
||||
commands = append(commands, mountCommands...)
|
||||
// 获取微调脚本执行命令
|
||||
startScript := t.ModelResource.FinetuningShellPath
|
||||
if t.ObjectStorage.MountType == schsdk.RcloneMount {
|
||||
startScript = schsdk.MountDir + "/" + t.ModelResource.FinetuningShellPath
|
||||
}
|
||||
// 获取启动命令
|
||||
commands = append(commands, utils.HandleCommand(startScript))
|
||||
|
||||
// 执行微调任务
|
||||
_, err := provider.RunCommand(commands, "i-bp1ikwdsr5r9p5i9mggm", 2000)
|
||||
_, err := provider.RunCommand(commands, instanceID, 2000)
|
||||
// 执行结束后销毁服务器
|
||||
//_, err2 := provider.DeleteInstance(instanceID)
|
||||
//if err2 != nil {
|
||||
// task.SendStatus(exectsk.NewScheduleCreateECSStatus("", "", err.Error()))
|
||||
// task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err))
|
||||
// return err2
|
||||
//}
|
||||
if err != nil {
|
||||
task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task.SendStatus(exectsk.NewSchedulerModelFinetuningStatus(nil))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDataPreprocessCommands builds the shell commands that write the data
// preprocessing script to /opt/generate_data.py on the target host, after
// substituting the inference platform's base URL and API key placeholders.
// It returns an error when no inference platform is configured.
//
// NOTE(review): fileContent starts as "" here, so every ReplaceAll below is a
// no-op and the script written out is empty; the @base_url@ substitution is
// repeated three times and the echo command is appended twice. This looks
// like stale code superseded by the data-preprocess task's own command
// builder — confirm before relying on it.
func getDataPreprocessCommands(envs []schsdk.KVPair, inferencePlatform schsdk.InferencePlatform) ([]string, error) {

	if inferencePlatform.PlatformName == "" {
		return nil, errors.New("inferencePlatform.PlatformName is empty")
	}

	var commands []string

	// 读取当前目录下的 data_preprocess.py 文件
	// (read data_preprocess.py from the current directory — the read itself
	// is not implemented here, leaving the content empty)
	fileContent := ""

	fileContent = strings.ReplaceAll(fileContent, "@base_url@", inferencePlatform.ApiBaseUrl)
	fileContent = strings.ReplaceAll(fileContent, "@api_key@", inferencePlatform.ApiKey)
	fileContent = strings.ReplaceAll(fileContent, "@input_file_path@", "")
	fileContent = strings.ReplaceAll(fileContent, "@output_file@", "")
	fileContent = strings.ReplaceAll(fileContent, "@base_url@", inferencePlatform.ApiBaseUrl)
	fileContent = strings.ReplaceAll(fileContent, "@base_url@", inferencePlatform.ApiBaseUrl)

	// Write the (substituted) script onto the remote host via echo -e.
	commandContent := "echo -e '" + fileContent + "' > /opt/generate_data.py"
	commands = append(commands, commandContent)
	commandContent = "echo -e '" + fileContent + "' > /opt/generate_data.py"
	commands = append(commands, commandContent)

	return commands, nil
}
|
||||
|
||||
func init() {
|
||||
Register(NewSchedulerModelFinetuning)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,283 @@
|
|||
import json
|
||||
|
||||
import re
|
||||
from scipy.spatial.distance import cosine
|
||||
import fitz
|
||||
import shutil
|
||||
import os
|
||||
import logging
|
||||
import backoff
|
||||
from openai import OpenAI
|
||||
import numpy as np
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
client = OpenAI(base_url="@base_url@", api_key="@api_key@")
|
||||
|
||||
# 设置输入和输出路径
|
||||
input_file_path = "@input_file_path@"
|
||||
output_dir = "saveChunk/"
|
||||
output_file_path = "@output_file_path@"
|
||||
|
||||
# 设置每个文本块的最大分词数量
|
||||
chunk_max_length = @chunk_max_length@
|
||||
# 分块阈值
|
||||
start_chunk_threshold = @start_chunk_threshold@
|
||||
# 相似度阈值
|
||||
similarity_threshold = @similarity_threshold@
|
||||
# 数据分析次数
|
||||
entries_per_file = @entries_per_file@
|
||||
|
||||
def read_file(file_path: str) -> str:
    """Return the entire contents of a UTF-8 text file as one string."""
    handle = open(file_path, "r", encoding="utf-8")
    try:
        return handle.read()
    finally:
        handle.close()
|
||||
|
||||
@backoff.on_exception(backoff.expo, Exception, max_tries=3)
def generate_single_entry(text: str):
    """Generate instruction-dataset entries for one text chunk via the chat model.

    Sends `text` to the module-level OpenAI-compatible `client`; the model is
    prompted to reply with one or more JSON objects.  Every object that parses
    and contains the keys {"instruction", "input", "output"} is appended to
    the result string followed by a trailing comma.

    The whole call is retried up to 3 times with exponential backoff on any
    exception (via the @backoff decorator, which relies on the `raise` below).

    Returns:
        str: concatenated "{...}," fragments (possibly empty); the caller is
        expected to strip the final comma when assembling the JSON array.
    """
    prompt = f"""
基于以下文本,生成1个用于指令数据集的高质量条目。条目应该直接关联到给定的文本内容,提出相关的问题或任务。
请确保生成多样化的指令类型,例如:
- 分析类:"分析..."
- 比较类:"比较..."
- 解释类:"解释..."
- 评价类:"评价..."
- 问答类:"为什么..."

文本内容:
{text}

请以下面的格式生成条目,确保所有字段都有适当的内容:
{{
"instruction": "使用上述多样化的指令类型之一,提出一个具体的、与文本相关的问题或任务",
"input": "如果需要额外的上下文信息,请在这里提供,否则跟上面的instruction保持一致",
"output": "对instruction的详细回答或任务的完成结果"
}}
确保所有生成的内容都与给定的文本直接相关,生成的是完整、有效的JSON格式,并且内容高质量、准确、详细,当有多个json时,用空行分隔。
"""

    try:

        # NOTE(review): the model name is hard-coded here while the service
        # config also carries an "llmModel" value — confirm they stay in sync.
        resp = client.chat.completions.create(
            model="glm4-chat01",
            messages=[
                {"role": "system", "content": "你是一个指令生成专家"},
                {"role": "user", "content": prompt}
            ]
        )
        response = resp.choices[0].message.content

        result = ""
        # Split the reply on "}" and re-append the brace to recover each
        # candidate JSON object.
        # NOTE(review): this assumes no "}" appears inside any string value of
        # the generated JSON; a brace inside "output" would break parsing here.
        jsonStrList = response.split("}")
        for item in jsonStrList:
            if item == "":
                continue
            json_str = item + "}"
            try:
                data = json.loads(json_str)
                # 检查必要的键是否存在 (check that the required keys exist)
                required_keys = {"instruction", "input", "output"}
                if not required_keys.issubset(data.keys()):
                    logger.error(f"生成的条目缺少一些关键字段,请检查:{required_keys - data.keys()}")
                    continue
                # Keep the fragment, re-adding the brace plus a separator comma.
                result = result + item + "},"
            except json.JSONDecodeError as e:
                # Malformed fragments are logged and dropped, not fatal.
                logger.error(f"解析JSON字符串时发生错误: {str(e)}, output: {json_str}")

        logger.info(f"output: {result}")
        return result

    except Exception as e:
        # Re-raise so the @backoff decorator can retry the whole request.
        logger.error(f"生成条目时发生错误: {str(e)}")
        raise
|
||||
|
||||
|
||||
def generate_dataset(folder_path: str, output_file_path, entries_per_file: int = 2):
    """Build a JSON-array dataset from every .txt chunk file under `folder_path`.

    For each chunk file, `entries_per_file` generation rounds are run through
    generate_single_entry(); the comma-terminated JSON fragments it returns
    are concatenated, the final trailing comma is replaced by the closing
    bracket, and the whole array text is written to `output_file_path`.

    Note: "\\n" is written literally (backslash + n, not a newline) on
    purpose — the script is shipped to the remote host via `echo -e`, which
    expands it there.

    Returns:
        str: the exact text written to the output file.
    """
    result = "[ \\n"
    for filename in os.listdir(folder_path):
        if filename.endswith(".txt"):
            file_path = os.path.join(folder_path, filename)
            # Fixed: the log line had lost its placeholder and printed a
            # constant; include the actual path being processed.
            logger.info(f"正在处理文件: {file_path}")
            text = read_file(file_path)
            for j in range(entries_per_file):
                logger.info(f" 生成第 {j + 1}/{entries_per_file} 个条目")
                entry = generate_single_entry(text)
                # Identity comparison for None (PEP 8); skip failed rounds.
                if entry is None:
                    logger.error("生成条目时发生错误,跳过当前条目")
                    continue
                result = result + entry

    # Drop the trailing comma of the last entry and close the array.
    result = result[:-1] + "\\n]"
    # 将结果写入到文件中 — write UTF-8 explicitly so the Chinese content
    # survives on platforms whose default encoding is not UTF-8.
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(result)
    return result
|
||||
|
||||
def get_sentence_embedding(sentence, client):
    """
    Get the embedding vector for a sentence.

    Args:
        sentence (str): the input sentence.
        client: OpenAI-compatible client instance.

    Returns:
        numpy.ndarray: the sentence's embedding vector.
    """
    # Use the Xinference embeddings API to obtain the sentence embedding.
    # NOTE(review): `embedding` is a module-level global. If the API call
    # fails, the *previous* sentence's embedding is silently returned — and a
    # failure on the very first call raises NameError at the return below.
    # Confirm this best-effort fallback is intended.
    # NOTE(review): model name is hard-coded while the service config also
    # carries an "embedModel" value — confirm they stay in sync.
    global embedding
    try:
        response = client.embeddings.create(model="bge-large-zh-v1.5", input=sentence)
        embedding = response.data[0].embedding
    except Exception as e:
        logger.error(e)
    return np.array(embedding)
|
||||
|
||||
def split_text_by_semantic(text, chunk_max_length, similarity_threshold=0.5):
    """
    Split a long text into chunks based on semantic similarity.

    Adjacent segments are merged into the current chunk while their embedding
    cosine similarity stays above `similarity_threshold` and the merged
    length does not exceed `chunk_max_length`; otherwise a new chunk starts.

    Args:
        text (str): the long input text.
        chunk_max_length (int): maximum length of each chunk (characters).
        similarity_threshold (float): similarity threshold, default 0.5.

    Returns:
        list: the list of text chunks.
    """
    chunks = []

    # Split the text into segments. (The punctuation-based split below was an
    # earlier approach; the active pattern splits on escaped blank lines as
    # they appear in the echo -e payload.)
    # sentences = re.split(r"(。|!|?|;)", text)
    sentences = re.split(r"\\n\s*\\n", text)
    # Re-combine segments pairwise.
    # NOTE(review): re.split without a capture group does not return the
    # separators, so this zip pairs consecutive *segments* (halving their
    # count) instead of re-attaching punctuation as the commented-out
    # capture-group version did — confirm this is intended.
    sentences = [s + p for s, p in zip(sentences[::2], sentences[1::2]) if s]

    current_chunk = sentences[0]
    # Embedding of the chunk currently being grown.
    current_embedding = get_sentence_embedding(current_chunk, client)

    for sentence in sentences[1:]:
        # Skip empty segments.
        if not sentence.strip():
            continue
        # Remove escaped blank lines inside the segment.
        sentence = re.sub(r"\\n\s*\\n", "", sentence)
        # Embedding of the candidate segment.
        sentence_embedding = get_sentence_embedding(sentence, client)
        # Cosine similarity between the current chunk and this segment.
        similarity = 1 - cosine(current_embedding, sentence_embedding)
        logger.info(f"similarity: {similarity}, and sentence: {sentence}")

        # Merge while similar enough and still under the length budget.
        if similarity > similarity_threshold and len(current_chunk + sentence) <= chunk_max_length:
            current_chunk += sentence
            # Running average keeps the chunk embedding roughly representative.
            current_embedding = (current_embedding + sentence_embedding) / 2
        else:
            # Otherwise close the current chunk and start a new one.
            chunks.append(current_chunk)
            current_chunk = sentence
            current_embedding = sentence_embedding

    # Flush the final chunk.
    if current_chunk:
        chunks.append(current_chunk)

    return chunks
|
||||
|
||||
|
||||
def read_text_file(file_path):
    """Read a UTF-8 encoded text file and return its full contents."""
    with open(file_path, mode="r", encoding="utf-8") as source:
        contents = source.read()
    return contents
|
||||
|
||||
|
||||
def save_chunks_to_files(chunks, output_dir):
    """
    Save each text chunk to its own numbered file under output_dir.

    Args:
        chunks (list): list of text chunks.
        output_dir (str): output directory path.
    """
    # Create the destination directory if it is not there yet.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # One file per chunk, numbered from 1.
    index = 0
    for chunk in chunks:
        index += 1
        chunk_file_path = os.path.join(output_dir, f"chunk_{index}.txt")
        with open(chunk_file_path, "w", encoding="utf-8") as file:
            file.write(chunk)
        logger.info(f"已保存第 {index} 个文本块到 {chunk_file_path}")
|
||||
|
||||
|
||||
def pdf_to_text(pdf_path, txt_path):
    """Extract the text of every page of a PDF into a UTF-8 text file.

    Uses PyMuPDF (imported as `fitz`); pages are written in document order
    with no extra separators between them.
    """
    pdf_document = fitz.open(pdf_path)
    with open(txt_path, "w", encoding="utf-8") as text_file:
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            text = page.get_text()
            text_file.write(text)
    # Release the document handle explicitly (the `with` only manages the
    # output file).
    pdf_document.close()
|
||||
|
||||
|
||||
def clean_dir(directory):
    """Remove a directory tree (if present) and recreate it empty.

    All failures are logged rather than raised, so the caller always
    proceeds; a pre-existing tree that cannot be removed is left in place.
    """
    # Best-effort removal of the old tree.
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        logger.info(f"文件夹 {directory} 不存在")
    except PermissionError:
        logger.info(f"没有权限删除文件夹 {directory}")
    except Exception as e:
        logger.info(f"删除失败: {e}")
    else:
        logger.info(f"成功删除文件夹及其内容: {directory}")

    # Recreate the (now empty) directory.
    try:
        os.makedirs(directory, exist_ok=True)
    except OSError as e:
        logger.info(f"创建文件夹失败: {e}")
    else:
        logger.info(f"成功创建文件夹: {directory}")
|
||||
|
||||
|
||||
def get_file_type(file_path):
    """Classify a path by its extension: "txt", "pdf", or "unknown"."""
    extension = os.path.splitext(file_path)[1].lower()
    if extension == ".txt":
        return "txt"
    if extension == ".pdf":
        return "pdf"
    return "unknown"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start from an empty chunk directory.
    clean_dir(output_dir)

    # Walk the (mounted) input directory and process every file found.
    for root, dirs, files in os.walk(input_file_path):
        for file in files:
            input_file = os.path.join(root, file)
            # PDFs are first converted to a sibling .txt; anything that is
            # neither txt nor pdf aborts the whole run.
            if get_file_type(input_file) == "pdf":
                pdf_to_text(input_file, input_file + ".txt")
                input_file = input_file + ".txt"
            elif get_file_type(input_file) == "unknown":
                raise ValueError("输入文件类型不正确,请输入文本文件或PDF文件")
            # 读取长文本 (read the long text)
            long_text = read_text_file(input_file)
            text_chunks = [long_text]

            # Only texts above the threshold are split semantically;
            # shorter texts are processed as a single chunk.
            if len(long_text) > start_chunk_threshold:
                text_chunks = split_text_by_semantic(long_text, chunk_max_length, similarity_threshold)

            # NOTE(review): output_dir is cleaned once before the loop, so a
            # later file producing fewer chunks leaves stale chunk_N.txt files
            # from the previous file in place — confirm whether clean_dir
            # should instead run per input file.
            save_chunks_to_files(text_chunks, output_dir)

            logger.info("开始生成数据集")
            # NOTE(review): this appends "/" to the module-level path on every
            # iteration; os.path.join tolerates it, but confirm it is intended.
            output_file_path = output_file_path + "/"
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            # One dataset JSON per input file, built from the chunk directory.
            output_file = os.path.join(output_file_path, file+".json")
            dataset = generate_dataset(output_dir, output_file, entries_per_file)
            logger.info(f"数据集已生成并保存到 {output_file}")
|
||||
|
|
@ -3,7 +3,6 @@ package jobmgr
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||||
"sync"
|
||||
|
||||
"gitlink.org.cn/cloudream/common/pkgs/future"
|
||||
|
@ -70,7 +69,6 @@ func (s *EventSet) Wait(ctx context.Context, cond EventWaitCondition) (Event, bo
|
|||
future: fut,
|
||||
}
|
||||
s.waiters = append(s.waiters, waiter)
|
||||
logger.Info("append waiter: %p", &waiter)
|
||||
|
||||
s.lock.Unlock()
|
||||
|
||||
|
|
|
@ -8,10 +8,11 @@ import (
|
|||
|
||||
type DataReturnJob struct {
|
||||
Info schsdk.DataReturnJobInfo
|
||||
TargetJobID schsdk.JobID // 目标任务的ID
|
||||
TargetJobCCID schsdk.CCID // 目标任务所在计算中心的ID
|
||||
TargetJobOutputPath string // 目标任务的结果输出路径,相对路径
|
||||
DataReturnPackageID cdssdk.PackageID // 回源之后得到的PackageID
|
||||
TargetJobID schsdk.JobID // 目标任务的ID
|
||||
TargetJobCCID schsdk.CCID // 目标任务所在计算中心的ID
|
||||
TargetJobOutputPath string // 目标任务的结果输出路径,相对路径
|
||||
DataReturnPackageID cdssdk.PackageID // 回源之后得到的PackageID
|
||||
ECSInstanceID schsdk.ECSInstanceID // ECS实例ID,在数据预处理和模型微调需要复用同一台机器时使用
|
||||
}
|
||||
|
||||
func NewDataReturnJob(info schsdk.DataReturnJobInfo) *DataReturnJob {
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
package job
|
||||
|
||||
import (
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
|
||||
)
|
||||
|
||||
type FinetuningJob struct {
|
||||
Info schsdk.FinetuningJobInfo // 提交任务时提供的任务描述信息
|
||||
Files jobmod.JobFiles // 任务需要的文件
|
||||
TargetCCID schsdk.CCID // 将要运行此任务的算力中心ID
|
||||
OutputPath string // 程序结果输出路径,一个相对路径,需要加上CDS数据库中记录的RemoteBase才是完整路径
|
||||
}
|
||||
|
||||
func NewFinetuningJob(info schsdk.FinetuningJobInfo) *FinetuningJob {
|
||||
return &FinetuningJob{
|
||||
Info: info,
|
||||
}
|
||||
}
|
||||
|
||||
func (j *FinetuningJob) GetInfo() schsdk.JobInfo {
|
||||
return &j.Info
|
||||
}
|
||||
|
||||
func (j *FinetuningJob) Dump() jobmod.JobBodyDump {
|
||||
return &jobmod.FinetuningJobDump{
|
||||
Files: j.Files,
|
||||
TargetCCID: j.TargetCCID,
|
||||
}
|
||||
}
|
|
@ -6,10 +6,12 @@ import (
|
|||
)
|
||||
|
||||
type NormalJob struct {
|
||||
Info schsdk.NormalJobInfo // 提交任务时提供的任务描述信息
|
||||
Files jobmod.JobFiles // 任务需要的文件
|
||||
TargetCCID schsdk.CCID // 将要运行此任务的算力中心ID
|
||||
OutputPath string // 程序结果输出路径,一个相对路径,需要加上CDS数据库中记录的RemoteBase才是完整路径
|
||||
Info schsdk.NormalJobInfo // 提交任务时提供的任务描述信息
|
||||
Files jobmod.JobFiles // 任务需要的文件
|
||||
TargetCCID schsdk.CCID // 将要运行此任务的算力中心ID
|
||||
OutputPath string // 程序结果输出路径,一个相对路径,需要加上CDS数据库中记录的RemoteBase才是完整路径
|
||||
SubType string // 用于区分普通任务下的子类型
|
||||
ECSInstanceID schsdk.ECSInstanceID // ECS实例ID,在数据预处理和模型微调需要复用同一台机器时使用
|
||||
}
|
||||
|
||||
func NewNormalJob(info schsdk.NormalJobInfo) *NormalJob {
|
||||
|
|
|
@ -79,11 +79,6 @@ func (s *Adjusting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
|
|||
jobFiles = &runningJob.Files
|
||||
runningJob.TargetCCID = s.scheme.TargetCCID
|
||||
runningJob.OutputPath = outputPath
|
||||
case *job.FinetuningJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
runningJob.TargetCCID = s.scheme.TargetCCID
|
||||
runningJob.OutputPath = outputPath
|
||||
case *job.MultiInstanceJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
|
|
|
@ -56,45 +56,73 @@ func (s *JobExecuting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
|
|||
|
||||
switch runningJob := jo.Body.(type) {
|
||||
case *job.NormalJob:
|
||||
pcmImgInfo, err := rtx.Mgr.DB.PCMImage().GetByImageIDAndCCID(rtx.Mgr.DB.SQLCtx(), runningJob.Files.Image.ImageID, runningJob.TargetCCID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting pcm image info: %w", err)
|
||||
}
|
||||
ress, err := rtx.Mgr.DB.CCResource().GetByCCID(rtx.Mgr.DB.SQLCtx(), runningJob.TargetCCID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting computing center resource: %w", err)
|
||||
}
|
||||
if len(ress) == 0 {
|
||||
return fmt.Errorf("no resource found at computing center %v", runningJob.TargetCCID)
|
||||
switch runningJob.SubType {
|
||||
case schsdk.JobTypeNormal: // 普通任务
|
||||
pcmImgInfo, err := rtx.Mgr.DB.PCMImage().GetByImageIDAndCCID(rtx.Mgr.DB.SQLCtx(), runningJob.Files.Image.ImageID, runningJob.TargetCCID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting pcm image info: %w", err)
|
||||
}
|
||||
ress, err := rtx.Mgr.DB.CCResource().GetByCCID(rtx.Mgr.DB.SQLCtx(), runningJob.TargetCCID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting computing center resource: %w", err)
|
||||
}
|
||||
if len(ress) == 0 {
|
||||
return fmt.Errorf("no resource found at computing center %v", runningJob.TargetCCID)
|
||||
}
|
||||
|
||||
ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting storage info: %w", err)
|
||||
}
|
||||
dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
|
||||
cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
err = s.submitNormalTask(rtx, cmd, envs, *ccInfo, pcmImgInfo, ress[0].PCMResourceID)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
case schsdk.JobTypeDataPreprocess: // 数据预处理
|
||||
ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting storage info: %w", err)
|
||||
}
|
||||
dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
|
||||
cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
instID, err := s.submitDataPreprocessTask(rtx, cmd, envs, *ccInfo, getStg.StorageID, userID)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
runningJob.ECSInstanceID = schsdk.ECSInstanceID(instID)
|
||||
|
||||
case schsdk.JobTypeFinetuning: // 模型微调
|
||||
ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting storage info: %w", err)
|
||||
}
|
||||
dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
|
||||
// 将整理的数据集提交到OSS
|
||||
if runningJob.Files.Dataset.ECSInstanceID != "" {
|
||||
logger.Infof("instance id: %v", runningJob.ECSInstanceID)
|
||||
dataSetPath, err = loadDatasetPackage(userID, runningJob.Files.Dataset.PackageID, getStg.StorageID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading dataset package: %w", err)
|
||||
}
|
||||
}
|
||||
cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
err = s.submitFinetuningTask(userID, rtx, cmd, envs, *ccInfo, getStg.StorageID, runningJob)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
case *job.InstanceJob: // 推理任务
|
||||
ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting storage info: %w", err)
|
||||
}
|
||||
cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, runningJob.Files.Dataset.PackageID, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
err = s.submitNormalTask(rtx, cmd, envs, *ccInfo, pcmImgInfo, ress[0].PCMResourceID)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
case *job.FinetuningJob:
|
||||
ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting storage info: %w", err)
|
||||
}
|
||||
cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, runningJob.Files.Dataset.PackageID, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
err = s.submitFinetuningTask(rtx, cmd, envs, *ccInfo)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
case *job.InstanceJob:
|
||||
ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting storage info: %w", err)
|
||||
}
|
||||
_, envs := getRuntimeCommand(runningJob.Info.Runtime, runningJob.Files.Dataset.PackageID, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
|
||||
_, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
|
||||
err = s.submitInstanceTask(rtx, jo, runningJob, *ccInfo, getStg.StorageID, userID, envs)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
|
@ -105,6 +133,31 @@ func (s *JobExecuting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func getDataSetPathByID(packageID cdssdk.PackageID) string {
|
||||
// TODO 临时使用,这个路径应该来自于CDS
|
||||
dataSetPath := filepath.Join("packages", "1", fmt.Sprintf("%v", packageID))
|
||||
return dataSetPath
|
||||
}
|
||||
|
||||
func loadDatasetPackage(userID cdssdk.UserID, packageID cdssdk.PackageID, storageID cdssdk.StorageID) (string, error) {
|
||||
stgCli, err := schglb.CloudreamStoragePool.Acquire()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer schglb.CloudreamStoragePool.Release(stgCli)
|
||||
|
||||
loadPackageResp, err := stgCli.StorageLoadPackage(cdssdk.StorageLoadPackageReq{
|
||||
UserID: userID,
|
||||
PackageID: packageID,
|
||||
StorageID: storageID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
logger.Info("load pacakge path: " + loadPackageResp.FullPath)
|
||||
return loadPackageResp.FullPath, nil
|
||||
}
|
||||
|
||||
func (s *JobExecuting) submitNormalTask(rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter, pcmImgInfo schmod.PCMImage, resourceID pcmsdk.ResourceID) error {
|
||||
task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSubmitTask(
|
||||
ccInfo.PCMParticipantID,
|
||||
|
@ -147,11 +200,51 @@ func (s *JobExecuting) submitNormalTask(rtx jobmgr.JobStateRunContext, cmd strin
|
|||
}
|
||||
}
|
||||
|
||||
func (s *JobExecuting) submitFinetuningTask(rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter) error {
|
||||
task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSchedulerModelFinetuning(
|
||||
func (s *JobExecuting) submitDataPreprocessTask(rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter, storageID cdssdk.StorageID, userID cdssdk.UserID) (string, error) {
|
||||
objectStorage, err := rtx.Mgr.DB.ObjectStorage().GetObjectStorageByStorageID(rtx.Mgr.DB.SQLCtx(), storageID)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return "", fmt.Errorf("getting object storage info: %w", err)
|
||||
}
|
||||
|
||||
task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSchedulerDataPreprocess(
|
||||
userID,
|
||||
cmd,
|
||||
envs,
|
||||
schglb.InferencePlatform,
|
||||
objectStorage,
|
||||
), ccInfo)
|
||||
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return "", err
|
||||
}
|
||||
|
||||
taskFut := task.Receive()
|
||||
msg := <-taskFut.Chan()
|
||||
tskStatus := msg.Value.Status.(*exetsk.SchedulerDataPreprocessStatus)
|
||||
|
||||
if tskStatus.Error != nil {
|
||||
logger.Error(tskStatus.Error.Error())
|
||||
return "", tskStatus.Error
|
||||
}
|
||||
|
||||
return tskStatus.InstanceID, nil
|
||||
}
|
||||
|
||||
func (s *JobExecuting) submitFinetuningTask(userID cdssdk.UserID, rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter, storageID cdssdk.StorageID, runningJob *job.NormalJob) error {
|
||||
|
||||
objectStorage, modelInfo, err := getModelInfoAndObjectStorage(rtx, runningJob.Info.ModelJobInfo.ModelID, storageID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting model info and object storage: %w", err)
|
||||
}
|
||||
|
||||
task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSchedulerModelFinetuning(
|
||||
userID,
|
||||
cmd,
|
||||
*objectStorage,
|
||||
*modelInfo,
|
||||
envs,
|
||||
string(runningJob.Files.Dataset.ECSInstanceID),
|
||||
), ccInfo)
|
||||
|
||||
if err != nil {
|
||||
|
@ -176,29 +269,17 @@ func (s *JobExecuting) submitInstanceTask(rtx jobmgr.JobStateRunContext, jo *job
|
|||
|
||||
modelJobInfo := runningJob.Info.ModelJobInfo
|
||||
|
||||
objectStorage, err := rtx.Mgr.DB.ObjectStorage().GetObjectStorageByStorageID(rtx.Mgr.DB.SQLCtx(), storageID)
|
||||
objectStorage, modelInfo, err := getModelInfoAndObjectStorage(rtx, modelJobInfo.ModelID, storageID)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return fmt.Errorf("getting object storage info: %w", err)
|
||||
}
|
||||
|
||||
// 先从数据库中查询是否已经预置了模型
|
||||
modelInfo, err := rtx.Mgr.DB.Models().GetModelByID(rtx.Mgr.DB.SQLCtx(), modelJobInfo.ModelID, objectStorage.ID)
|
||||
if &modelInfo == nil {
|
||||
logger.Error(err.Error())
|
||||
return fmt.Errorf("the model is not exists: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return fmt.Errorf("getting model info info: %w", err)
|
||||
return fmt.Errorf("getting model info and object storage: %w", err)
|
||||
}
|
||||
|
||||
// 发送扩容任务
|
||||
ecs := exetsk.NewScheduleCreateECS(
|
||||
userID,
|
||||
runningJob.Info.Runtime.Command+"\\n"+modelJobInfo.Command,
|
||||
objectStorage,
|
||||
modelInfo,
|
||||
*objectStorage,
|
||||
*modelInfo,
|
||||
envs,
|
||||
)
|
||||
task, err := rtx.Mgr.ExecMgr.StartTask(ecs, ccInfo)
|
||||
|
@ -289,6 +370,27 @@ func (s *JobExecuting) submitInstanceTask(rtx jobmgr.JobStateRunContext, jo *job
|
|||
}
|
||||
}
|
||||
|
||||
func getModelInfoAndObjectStorage(rtx jobmgr.JobStateRunContext, modelID schsdk.ModelID, storageID cdssdk.StorageID) (*schmod.ObjectStorage, *schmod.ModelResource, error) {
|
||||
objectStorage, err := rtx.Mgr.DB.ObjectStorage().GetObjectStorageByStorageID(rtx.Mgr.DB.SQLCtx(), storageID)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return nil, nil, fmt.Errorf("getting object storage info: %w", err)
|
||||
}
|
||||
|
||||
// 先从数据库中查询是否已经预置了模型
|
||||
modelInfo, err := rtx.Mgr.DB.Models().GetModelByID(rtx.Mgr.DB.SQLCtx(), modelID, objectStorage.ID)
|
||||
if &modelInfo == nil {
|
||||
logger.Error(err.Error())
|
||||
return nil, nil, fmt.Errorf("the model is not exists: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return nil, nil, fmt.Errorf("getting model info info: %w", err)
|
||||
}
|
||||
|
||||
return &objectStorage, &modelInfo, nil
|
||||
}
|
||||
|
||||
func postDeleteInstanceEvent(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job, runningJob *job.InstanceJob) {
|
||||
deleteInfo := event.InstanceDeleteInfo{
|
||||
InstanceID: jo.JobID,
|
||||
|
@ -299,13 +401,11 @@ func postDeleteInstanceEvent(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job, runn
|
|||
}
|
||||
|
||||
// 判断算力中心是否支持环境变量配置,如果不支持,则读取脚本内容并拼接在Command参数后面
|
||||
func getRuntimeCommand(runtime schsdk.JobRuntimeInfo, packageID cdssdk.PackageID, outputPath string, remoteBase string, ccInfo schmod.ComputingCenter) (string, []schsdk.KVPair) {
|
||||
func getRuntimeCommand(runtime schsdk.JobRuntimeInfo, dataSetPath string, outputPath string, remoteBase string, ccInfo schmod.ComputingCenter) (string, []schsdk.KVPair) {
|
||||
var envs []schsdk.KVPair
|
||||
var params []string
|
||||
var cmd string
|
||||
|
||||
// TODO 临时使用,这个路径应该来自于CDS
|
||||
dataSetPath := filepath.Join("packages", "1", fmt.Sprintf("%v", packageID))
|
||||
envs = append(envs, schsdk.KVPair{Key: schsdk.JobDataInEnv, Value: filepath.Join(remoteBase, dataSetPath)})
|
||||
envs = append(envs, schsdk.KVPair{Key: schsdk.JobDataOutEnv, Value: filepath.Join(remoteBase, outputPath)})
|
||||
envs = append(envs, runtime.Envs...)
|
||||
|
|
|
@ -43,7 +43,7 @@ func (s *MultiInstanceRunning) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job)
|
|||
|
||||
multInstJob := jo.Body.(*job.MultiInstanceJob)
|
||||
|
||||
go pollingInstance(rtx, multInstJob)
|
||||
//go pollingInstance(rtx, multInstJob)
|
||||
|
||||
waitFut := event.BeginWaitType[*event.InstanceOperate](rtx.EventSet)
|
||||
for {
|
||||
|
@ -61,6 +61,12 @@ func (s *MultiInstanceRunning) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job)
|
|||
// 微调任务特殊处理
|
||||
if info.Info.UpdateType == schsdk.FineTuning {
|
||||
multInstJob.Info.ModelJobInfo.Command = info.Info.Runtime.Command
|
||||
// 从原有配置中删除微调的输出路径,防止冲突
|
||||
for i := 0; i < len(multInstJob.Info.Runtime.Envs); i++ {
|
||||
if multInstJob.Info.Runtime.Envs[i].Key == schsdk.FinetuningOutEnv {
|
||||
multInstJob.Info.Runtime.Envs = append(multInstJob.Info.Runtime.Envs[:i], multInstJob.Info.Runtime.Envs[i+1:]...)
|
||||
}
|
||||
}
|
||||
multInstJob.Info.Runtime.Envs = append(multInstJob.Info.Runtime.Envs, info.Info.Runtime.Envs...)
|
||||
subJobs = multInstJob.SubJobs
|
||||
}
|
||||
|
@ -115,7 +121,7 @@ func updateInstance(rtx jobmgr.JobStateRunContext, updateInfo *event.InstanceUpd
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
fut := future.NewSetValue[event.UpdateResult]()
|
||||
rtx.Mgr.PostEvent(instanceID, event.NewUpdate(updateInfo.Info.Runtime, updateInfo.Info.Operate, fut))
|
||||
rtx.Mgr.PostEvent(instanceID, event.NewUpdate(updateInfo.Info.Runtime, schsdk.RestartServer, fut))
|
||||
_, err := fut.Wait(context.TODO())
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -49,7 +49,7 @@ func (s *MultiInstanceUpdate) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job)
|
|||
if instanceJob.Info.UpdateType == schsdk.FineTuning {
|
||||
var dtrJob *job.DataReturnJob
|
||||
// 等待回源任务完成
|
||||
if rt, ok := updateJob.Info.Files.Code.(*schsdk.DataReturnJobFileInfo); ok {
|
||||
if rt, ok := updateJob.Info.Files.Dataset.(*schsdk.DataReturnJobFileInfo); ok {
|
||||
evt, ok := event.WaitTypeAnd[*event.JobCompleted](ctx, rtx.EventSet, func(val *event.JobCompleted) bool {
|
||||
return val.Job.GetInfo().GetLocalJobID() == rt.DataReturnLocalJobID
|
||||
})
|
||||
|
@ -87,6 +87,9 @@ func (s *MultiInstanceUpdate) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job)
|
|||
PackageID: dtrJob.DataReturnPackageID,
|
||||
StorageID: getStg.StorageID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading package: %w", err)
|
||||
}
|
||||
logger.Info("load pacakge path: " + loadPackageResp.FullPath)
|
||||
fullPath = loadPackageResp.FullPath
|
||||
}
|
||||
|
|
|
@ -42,10 +42,6 @@ func (s *PreScheduling) Run(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) {
|
|||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
runningJob.TargetCCID = s.scheme.TargetCCID
|
||||
case *job.FinetuningJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
runningJob.TargetCCID = s.scheme.TargetCCID
|
||||
case *job.MultiInstanceJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
|
|
|
@ -3,7 +3,6 @@ package state
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||||
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
|
||||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr"
|
||||
|
@ -35,9 +34,6 @@ func (s *ReadyToAdjust) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error
|
|||
case *job.NormalJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
case *job.FinetuningJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
case *job.InstanceJob:
|
||||
jobFilesInfo = runningJob.Info.Files
|
||||
jobFiles = &runningJob.Files
|
||||
|
@ -68,6 +64,7 @@ func (s *ReadyToAdjust) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error
|
|||
}
|
||||
|
||||
jobFiles.Dataset.PackageID = rtJob.DataReturnPackageID
|
||||
jobFiles.Dataset.ECSInstanceID = rtJob.ECSInstanceID
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -55,6 +55,7 @@ func (s *WaitTargetComplete) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) e
|
|||
reJob.TargetJobID = evt.Job.JobID
|
||||
reJob.TargetJobCCID = norJob.TargetCCID
|
||||
reJob.TargetJobOutputPath = norJob.OutputPath
|
||||
reJob.ECSInstanceID = norJob.ECSInstanceID
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -99,8 +99,14 @@ func (s *NodeService) GetNodeUsageRateInfo(customModelName schsdk.ModelName, mod
|
|||
}
|
||||
for i := 0; i < len(value.Nodes); i++ {
|
||||
node := value.Nodes[i]
|
||||
c := s.NodeUsageCache[node.InstanceID]
|
||||
c, ok := s.NodeUsageCache[node.InstanceID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
rateInfo := getCacheData(c)
|
||||
rateInfo.InstanceID = node.InstanceID
|
||||
rateInfo.Address = node.Address
|
||||
|
||||
rateInfos = append(rateInfos, rateInfo)
|
||||
}
|
||||
|
@ -127,6 +133,7 @@ func getCacheData(c *cache.Cache) schsdk.NodeUsageRateInfo {
|
|||
infoMap := make(map[string][]schsdk.UsageRate)
|
||||
|
||||
// 获取缓存中的所有项
|
||||
|
||||
items := c.Items()
|
||||
|
||||
// 遍历缓存项,将其放入 map 中
|
||||
|
|
|
@ -29,6 +29,7 @@ func (svc *Service) SubmitJobSet(msg *mgrmq.SubmitJobSet) (*mgrmq.SubmitJobSetRe
|
|||
switch info := jobInfo.(type) {
|
||||
case *schsdk.NormalJobInfo:
|
||||
jo := job.NewNormalJob(*info)
|
||||
jo.SubType = schsdk.JobTypeNormal
|
||||
|
||||
preSch, ok := msg.PreScheduleScheme.JobSchemes[info.LocalJobID]
|
||||
if !ok {
|
||||
|
@ -83,8 +84,40 @@ func (svc *Service) SubmitJobSet(msg *mgrmq.SubmitJobSet) (*mgrmq.SubmitJobSetRe
|
|||
InitState: state.NewMultiInstanceUpdate(multiInstanceJobDump),
|
||||
})
|
||||
|
||||
case *schsdk.DataPreprocessJobInfo:
|
||||
normalJobInfo := &schsdk.NormalJobInfo{
|
||||
Type: schsdk.JobTypeNormal,
|
||||
JobInfoBase: info.JobInfoBase,
|
||||
Files: info.Files,
|
||||
Runtime: info.Runtime,
|
||||
Services: info.Services,
|
||||
Resources: info.Resources,
|
||||
}
|
||||
jo := job.NewNormalJob(*normalJobInfo)
|
||||
jo.SubType = schsdk.JobTypeDataPreprocess
|
||||
|
||||
preSch, ok := msg.PreScheduleScheme.JobSchemes[info.LocalJobID]
|
||||
if !ok {
|
||||
return nil, mq.Failed(errorcode.OperationFailed, fmt.Sprintf("pre schedule scheme for job %s is not found", info.LocalJobID))
|
||||
}
|
||||
|
||||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||||
Body: jo,
|
||||
InitState: state.NewPreSchuduling(preSch),
|
||||
})
|
||||
|
||||
case *schsdk.FinetuningJobInfo:
|
||||
jo := job.NewFinetuningJob(*info)
|
||||
normalJobInfo := &schsdk.NormalJobInfo{
|
||||
Type: schsdk.JobTypeNormal,
|
||||
Files: info.Files,
|
||||
JobInfoBase: info.JobInfoBase,
|
||||
Runtime: info.Runtime,
|
||||
Services: info.Services,
|
||||
Resources: info.Resources,
|
||||
ModelJobInfo: info.ModelJobInfo,
|
||||
}
|
||||
jo := job.NewNormalJob(*normalJobInfo)
|
||||
jo.SubType = schsdk.JobTypeFinetuning
|
||||
|
||||
preSch, ok := msg.PreScheduleScheme.JobSchemes[info.LocalJobID]
|
||||
if !ok {
|
||||
|
|
|
@ -11,13 +11,13 @@ import (
|
|||
|
||||
type JobTask[T any] struct {
|
||||
id string
|
||||
taskChan async.UnboundChannel[T]
|
||||
taskChan *async.UnboundChannel[T]
|
||||
}
|
||||
|
||||
func NewJobTask[T any]() *JobTask[T] {
|
||||
return &JobTask[T]{
|
||||
id: getTaskID(),
|
||||
taskChan: *async.NewUnboundChannel[T](),
|
||||
taskChan: async.NewUnboundChannel[T](),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,7 +43,7 @@ func (c *JobTask[T]) Send(info any) {
|
|||
}
|
||||
|
||||
func (c *JobTask[T]) Chan() *async.UnboundChannel[T] {
|
||||
return &c.taskChan
|
||||
return c.taskChan
|
||||
}
|
||||
|
||||
func (c *JobTask[T]) ID() string {
|
||||
|
|
Loading…
Reference in New Issue