// Source path: JCC-CSScheduler/manager/internal/jobmgr/job/state/executing.go
//
// Web viewer metadata: 523 lines, 16 KiB, Go.
//
// Viewer warning: this file contains Unicode characters that might be
// confused with other characters (ambiguous Unicode). If this is intentional,
// the warning can be safely ignored.

package state
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/samber/lo"
"gitlink.org.cn/cloudream/common/pkgs/future"
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/common/sdks/storage/cdsapi"
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
"gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor"
"gitlink.org.cn/cloudream/scheduler/manager/internal/executormgr"
"gitlink.org.cn/cloudream/common/pkgs/logger"
pcmsdk "gitlink.org.cn/cloudream/common/sdks/pcm"
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
exetsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
"gitlink.org.cn/cloudream/scheduler/common/utils"
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr"
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr/event"
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr/job"
)
// JobExecuting is the job state that submits a job's task to an executor and
// waits for its outcome (see Run and do).
type JobExecuting struct {
// lastStatus is the most recent task status seen by submitNormalTask. It is
// initialised to "Begin" by NewNormalJobExecuting and reported by Dump.
lastStatus pcmsdk.TaskStatus
}
// NewNormalJobExecuting returns a JobExecuting state whose last known task
// status is the placeholder "Begin".
func NewNormalJobExecuting() *JobExecuting {
	st := new(JobExecuting)
	st.lastStatus = "Begin"
	return st
}
// Run executes the job via do and then moves the job to its final state:
// FailureComplete carrying the error when do fails, SuccessComplete otherwise.
func (s *JobExecuting) Run(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) {
	if runErr := s.do(rtx, jo); runErr != nil {
		rtx.Mgr.ChangeState(jo, FailureComplete(runErr))
		return
	}

	rtx.Mgr.ChangeState(jo, SuccessComplete())
}
// Dump returns a snapshot of this state containing the last task status
// recorded in s.lastStatus. rtx and jo are not used.
func (s *JobExecuting) Dump(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) jobmod.JobStateDump {
	snapshot := jobmod.NormalJobExecutingDump{TaskStatus: s.lastStatus}
	return &snapshot
}
// do prepares and submits the task that matches the job body and waits for
// the submit helper to return.
//
// Dispatch:
//   - *job.NormalJob with SubType JobTypeNormal: submitNormalTask
//   - *job.NormalJob with SubType JobTypeDataPreprocess: submitDataPreprocessTask;
//     the returned instance ID is stored in runningJob.ECSInstanceID
//   - *job.NormalJob with SubType JobTypeFinetuning: submitFinetuningTask
//   - *job.InstanceJob: submitInstanceTask; on failure the instance is also
//     removed from its parent job via postDeleteInstanceEvent
//
// Any other body type or sub type is left untouched and nil is returned.
//
// Errors from the submit helpers are returned to the caller. Previously they
// were assigned to a shadowed err variable and lost, so Run reported success
// for failed submissions.
func (s *JobExecuting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
	// TODO UserID
	userID := cdssdk.UserID(1)

	switch runningJob := jo.Body.(type) {
	case *job.NormalJob:
		switch runningJob.SubType {
		case schsdk.JobTypeNormal: // normal job
			pcmImgInfo, err := rtx.Mgr.DB.PCMImage().GetByImageIDAndCCID(rtx.Mgr.DB.SQLCtx(), runningJob.Files.Image.ImageID, runningJob.TargetCCID)
			if err != nil {
				return fmt.Errorf("getting pcm image info: %w", err)
			}

			ress, err := rtx.Mgr.DB.CCResource().GetByCCID(rtx.Mgr.DB.SQLCtx(), runningJob.TargetCCID)
			if err != nil {
				return fmt.Errorf("getting computing center resource: %w", err)
			}
			if len(ress) == 0 {
				return fmt.Errorf("no resource found at computing center %v", runningJob.TargetCCID)
			}

			ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
			if err != nil {
				return fmt.Errorf("getting storage info: %w", err)
			}

			dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
			cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)

			// The first resource of the computing center is always used.
			if err := s.submitNormalTask(rtx, cmd, envs, *ccInfo, pcmImgInfo, ress[0].PCMResourceID); err != nil {
				logger.Error(err.Error())
				return fmt.Errorf("submitting normal task: %w", err)
			}

		case schsdk.JobTypeDataPreprocess: // data preprocessing
			ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
			if err != nil {
				return fmt.Errorf("getting storage info: %w", err)
			}

			dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
			cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)

			instID, err := s.submitDataPreprocessTask(rtx, cmd, envs, *ccInfo, getStg.StorageID, userID)
			if err != nil {
				logger.Error(err.Error())
				return fmt.Errorf("submitting data preprocess task: %w", err)
			}
			// Only record the instance ID once the task has succeeded.
			runningJob.ECSInstanceID = schsdk.ECSInstanceID(instID)

		case schsdk.JobTypeFinetuning: // model fine-tuning
			ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
			if err != nil {
				return fmt.Errorf("getting storage info: %w", err)
			}

			dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
			// Load the prepared dataset package into the storage when the
			// dataset is bound to an ECS instance.
			if runningJob.Files.Dataset.ECSInstanceID != "" {
				// NOTE(review): this logs the job's ECSInstanceID while the
				// condition tests the dataset's ECSInstanceID - confirm
				// which one is intended.
				logger.Infof("instance id: %v", runningJob.ECSInstanceID)
				dataSetPath, err = loadDatasetPackage(userID, runningJob.Files.Dataset.PackageID, getStg.StorageID)
				if err != nil {
					return fmt.Errorf("loading dataset package: %w", err)
				}
			}

			cmd, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)
			if err := s.submitFinetuningTask(userID, rtx, cmd, envs, *ccInfo, getStg.StorageID, runningJob); err != nil {
				logger.Error(err.Error())
				return fmt.Errorf("submitting finetuning task: %w", err)
			}
		}

	case *job.InstanceJob: // inference job
		ccInfo, getStg, err := getCCInfoAndStgInfo(rtx, runningJob.TargetCCID, userID)
		if err != nil {
			return fmt.Errorf("getting storage info: %w", err)
		}

		dataSetPath := getDataSetPathByID(runningJob.Files.Dataset.PackageID)
		// Only the environment variables are needed here; the command is
		// assembled inside submitInstanceTask.
		_, envs := getRuntimeCommand(runningJob.Info.Runtime, dataSetPath, runningJob.OutputPath, getStg.RemoteBase, *ccInfo)

		if err := s.submitInstanceTask(rtx, jo, runningJob, *ccInfo, getStg.StorageID, userID, envs); err != nil {
			logger.Error(err.Error())
			// Creation failed: remove this instance from the multi-instance job.
			postDeleteInstanceEvent(rtx, jo, runningJob)
			return fmt.Errorf("submitting instance task: %w", err)
		}
	}

	return nil
}
// getDataSetPathByID returns the relative path "packages/1/<packageID>"
// (joined with the OS path separator) for the given package.
func getDataSetPathByID(packageID cdssdk.PackageID) string {
	// TODO: temporary; this path should come from CDS.
	return filepath.Join("packages", "1", fmt.Sprint(packageID))
}
// loadDatasetPackage asks the storage service to load the given package into
// the given storage on behalf of userID and returns the FullPath reported in
// the response. The storage client is taken from schglb.CloudreamStoragePool
// and released before returning.
func loadDatasetPackage(userID cdssdk.UserID, packageID cdssdk.PackageID, storageID cdssdk.StorageID) (string, error) {
	cli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return "", err
	}
	defer schglb.CloudreamStoragePool.Release(cli)

	req := cdsapi.StorageLoadPackageReq{
		UserID:    userID,
		PackageID: packageID,
		StorageID: storageID,
	}
	resp, err := cli.StorageLoadPackage(req)
	if err != nil {
		return "", err
	}

	fullPath := resp.FullPath
	logger.Info("load pacakge path: " + fullPath)
	return fullPath, nil
}
// submitNormalTask starts a submit task on the executor of ccInfo and blocks
// until the task reports a terminal status.
//
// Each received status is stored in s.lastStatus (which Dump reports), and a
// change of status is logged. The function returns nil when the status is
// pcmsdk.TaskStatusSuccess or "Completed", and an error when the status is
// pcmsdk.TaskStatusFailed, when the task cannot be started, or when a status
// message of an unexpected type is received (previously this panicked).
func (s *JobExecuting) submitNormalTask(rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter, pcmImgInfo schmod.PCMImage, resourceID pcmsdk.ResourceID) error {
	task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSubmitTask(
		ccInfo.PCMParticipantID,
		pcmImgInfo.PCMImageID,
		// TODO: algorithm for choosing the resource
		resourceID,
		cmd,
		envs,
		// params. TODO: params should be a string slice, not a KV slice
		[]schsdk.KVPair{},
	), ccInfo)
	if err != nil {
		logger.Error(err.Error())
		return err
	}

	taskFut := task.Receive()
	for {
		msg := <-taskFut.Chan()

		tskStatus, ok := msg.Value.Status.(*exetsk.SubmitTaskStatus)
		if !ok || tskStatus == nil {
			return fmt.Errorf("unexpected submit task status type %T", msg.Value.Status)
		}

		if tskStatus.Status != s.lastStatus {
			logger.Infof("task %s -> %s", s.lastStatus, tskStatus.Status)
		}
		s.lastStatus = tskStatus.Status

		switch tskStatus.Status {
		case pcmsdk.TaskStatusSuccess, "Completed":
			return nil
		case pcmsdk.TaskStatusFailed:
			return fmt.Errorf("task failed")
		}

		// Not terminal yet: wait for the next status report.
		taskFut = task.Receive()
	}
}
// submitDataPreprocessTask looks up the object storage bound to storageID,
// starts a data preprocess task on the executor of ccInfo and waits for its
// first status message.
//
// It returns the InstanceID carried by that status. An error is returned when
// the object storage lookup fails, the task cannot be started, the status
// carries an error, or the status message has an unexpected type (previously
// this panicked).
func (s *JobExecuting) submitDataPreprocessTask(rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter, storageID cdssdk.StorageID, userID cdssdk.UserID) (string, error) {
	objectStorage, err := rtx.Mgr.DB.ObjectStorage().GetObjectStorageByStorageID(rtx.Mgr.DB.SQLCtx(), storageID)
	if err != nil {
		logger.Error(err.Error())
		return "", fmt.Errorf("getting object storage info: %w", err)
	}

	task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSchedulerDataPreprocess(
		userID,
		cmd,
		envs,
		objectStorage,
	), ccInfo)
	if err != nil {
		logger.Error(err.Error())
		return "", err
	}

	// Only a single status message is awaited for this task.
	msg := <-task.Receive().Chan()
	tskStatus, ok := msg.Value.Status.(*exetsk.SchedulerDataPreprocessStatus)
	if !ok || tskStatus == nil {
		return "", fmt.Errorf("unexpected data preprocess status type %T", msg.Value.Status)
	}
	if tskStatus.Error != nil {
		logger.Error(tskStatus.Error.Error())
		return "", tskStatus.Error
	}

	return tskStatus.InstanceID, nil
}
// submitFinetuningTask resolves the object storage and model resource for the
// job, starts a model fine-tuning task on the executor of ccInfo and waits for
// its first status message.
//
// The task is given the dataset's ECSInstanceID from runningJob.Files.Dataset.
// An error is returned when the model/object storage lookup fails, the task
// cannot be started, the status carries an error, or the status message has
// an unexpected type (previously this panicked).
func (s *JobExecuting) submitFinetuningTask(userID cdssdk.UserID, rtx jobmgr.JobStateRunContext, cmd string, envs []schsdk.KVPair, ccInfo schmod.ComputingCenter, storageID cdssdk.StorageID, runningJob *job.NormalJob) error {
	objectStorage, modelInfo, err := getModelInfoAndObjectStorage(rtx, runningJob.Info.ModelJobInfo.ModelID, storageID)
	if err != nil {
		return fmt.Errorf("getting model info and object storage: %w", err)
	}

	task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSchedulerModelFinetuning(
		userID,
		cmd,
		*objectStorage,
		*modelInfo,
		envs,
		string(runningJob.Files.Dataset.ECSInstanceID),
	), ccInfo)
	if err != nil {
		logger.Error(err.Error())
		return err
	}

	// Only a single status message is awaited for this task.
	msg := <-task.Receive().Chan()
	tskStatus, ok := msg.Value.Status.(*exetsk.SchedulerModelFinetuningStatus)
	if !ok || tskStatus == nil {
		return fmt.Errorf("unexpected model finetuning status type %T", msg.Value.Status)
	}
	if tskStatus.Error != nil {
		logger.Error(tskStatus.Error.Error())
		return tskStatus.Error
	}

	return nil
}
// submitInstanceTask starts a ScheduleCreateECS task for an instance job on
// the executor of ccInfo and then loops, multiplexing two sources:
//
//   - *event.Update events from rtx.EventSet, which are forwarded to the
//     executor as task operations, and
//   - status messages from the task, which are applied to the node
//     bookkeeping in rtx.Mgr.NodeSvc.
//
// Once the loop has started, the only returns are the error returns in the
// update-event branch; the status branch never returns.
func (s *JobExecuting) submitInstanceTask(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job, runningJob *job.InstanceJob, ccInfo schmod.ComputingCenter,
storageID cdssdk.StorageID, userID cdssdk.UserID, envs []schsdk.KVPair) error {
modelJobInfo := runningJob.Info.ModelJobInfo
objectStorage, modelInfo, err := getModelInfoAndObjectStorage(rtx, modelJobInfo.ModelID, storageID)
if err != nil {
return fmt.Errorf("getting model info and object storage: %w", err)
}
// Send the scale-out task.
ecs := exetsk.NewScheduleCreateECS(
userID,
// The two commands are joined with a literal backslash followed by 'n'
// (two characters), not a newline character.
// NOTE(review): confirm the receiving side expects the escaped form.
runningJob.Info.Runtime.Command+"\\n"+modelJobInfo.Command,
*objectStorage,
*modelInfo,
envs,
)
task, err := rtx.Mgr.ExecMgr.StartTask(ecs, ccInfo)
if err != nil {
logger.Error(err.Error())
return err
}
waitFut := event.BeginWaitType[*event.Update](rtx.EventSet)
taskFut := task.Receive()
for {
select {
case v1 := <-waitFut.Chan():
// Perform the requested update operation on the task.
// NOTE(review): no release of the acquired client is visible in this
// function - confirm whether ExecutorPool requires one.
client, err := executormgr.ExecutorPool.AcquireByUrl(ccInfo.ExecutorURL)
if err != nil {
return fmt.Errorf("getting executor client: %w", err)
}
evt := v1.Value.(*event.Update)
operateResp, err := client.OperateTask(executor.NewTaskOperateInfo(task.ID(), evt.Operate, evt.Runtime))
if err != nil {
return fmt.Errorf("operate task: %w", err)
}
// The result is handed back to the event's sender before the error is
// checked, so the sender sees the failure as well.
evt.Result.SetValue(event.UpdateResult{
Err: operateResp.Err,
})
if operateResp.Err != nil {
return fmt.Errorf("operate task: %w", operateResp.Err)
}
// Keep waiting for the next update event.
waitFut = event.BeginWaitType[*event.Update](rtx.EventSet)
case msg := <-taskFut.Chan():
switch v2 := msg.Value.Status.(type) {
case *exetsk.ScheduleCreateECSStatus:
if v2.Error != "" {
logger.Error("update task fail, error: " + v2.Error)
if v2.Operate == schsdk.CreateECS || v2.Operate == schsdk.Invalid {
// Creation failed or the instance was detected as unavailable:
// treat it as a destroy so it is removed from the multi-instance job.
v2.Operate = schsdk.DestroyECS
} else {
// NOTE(review): this continue skips the `taskFut = task.Receive()` at
// the bottom of the case, so the next iteration selects on the same
// future again - confirm that is intended.
continue
}
}
switch v2.Operate {
case schsdk.CreateECS:
// Scale-out: put the resulting node into the pool.
node := schsdk.NodeInfo{
InstanceID: jo.JobID,
Address: schsdk.Address(v2.Result),
Status: schsdk.RunECS,
}
rtx.Mgr.NodeSvc.SetNodeData(jo.JobSetID, modelJobInfo, node)
logger.Infof("node expansion: %v", v2.Result)
case schsdk.DestroyECS:
// Scale-in: remove the node from the node list.
rtx.Mgr.NodeSvc.RemoveNodeFromRunningModels(modelJobInfo, jo.JobID)
// Remove the instance from the multi-instance job.
postDeleteInstanceEvent(rtx, jo, runningJob)
case schsdk.PauseECS:
// Update the node status.
rtx.Mgr.NodeSvc.UpdateNodeFromRunningModels(modelJobInfo, jo.JobID, schsdk.PauseECS)
case schsdk.RunECS:
// Update the node status.
rtx.Mgr.NodeSvc.UpdateNodeFromRunningModels(modelJobInfo, jo.JobID, schsdk.RunECS)
case schsdk.OperateServer:
// Nothing is done for this operation apart from the builtin println call.
println()
case schsdk.GPUMonitor:
rtx.Mgr.NodeSvc.SetNodeUsageRateInfo(jo.JobID, v2.Result)
}
case error:
fmt.Println("Received error:", v2.Error())
default:
fmt.Println("Received unexpected type")
}
// Keep receiving task status messages.
taskFut = task.Receive()
}
}
}
// getModelInfoAndObjectStorage loads the object storage bound to storageID
// and then the model resource identified by modelID within that object
// storage.
//
// It returns pointers to local copies of both records, or an error (also
// logged) when either database lookup fails.
//
// The former `&modelInfo == nil` check was removed: the address of a local
// variable is never nil, so the branch was unreachable, and it would have
// dereferenced a possibly nil err had it ever run. A missing model can only
// be detected here through the error returned by GetModelByID.
func getModelInfoAndObjectStorage(rtx jobmgr.JobStateRunContext, modelID schsdk.ModelID, storageID cdssdk.StorageID) (*schmod.ObjectStorage, *schmod.ModelResource, error) {
	objectStorage, err := rtx.Mgr.DB.ObjectStorage().GetObjectStorageByStorageID(rtx.Mgr.DB.SQLCtx(), storageID)
	if err != nil {
		logger.Error(err.Error())
		return nil, nil, fmt.Errorf("getting object storage info: %w", err)
	}

	// Look up in the database whether the model has been preset.
	modelInfo, err := rtx.Mgr.DB.Models().GetModelByID(rtx.Mgr.DB.SQLCtx(), modelID, objectStorage.ID)
	if err != nil {
		logger.Error(err.Error())
		return nil, nil, fmt.Errorf("getting model info info: %w", err)
	}

	return &objectStorage, &modelInfo, nil
}
// postDeleteInstanceEvent posts an instance-delete event for this job to the
// parent job (runningJob.ParentJobID) and blocks until the event's result
// future completes. The result and any wait error are discarded.
func postDeleteInstanceEvent(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job, runningJob *job.InstanceJob) {
	resultFut := future.NewSetValue[event.OperateInstanceResult]()
	info := &event.InstanceDeleteInfo{InstanceID: jo.JobID}

	rtx.Mgr.PostEvent(runningJob.ParentJobID, event.NewInstanceOperate(info, resultFut))

	// Deliberately ignore the outcome of the wait.
	_, _ = resultFut.Wait(context.TODO())
}
// getRuntimeCommand builds the command and environment variables for a job.
//
// The returned envs always start with the data-in path (remoteBase joined
// with dataSetPath) and the data-out path (remoteBase joined with outputPath),
// followed by runtime.Envs.
//
// The returned command depends on the bootstrap type of the computing center:
// for *schsdk.NoEnvBootstrap it is the bootstrap's ScriptFileName, for every
// other bootstrap type it is runtime.Command.
//
// NOTE(review): in the NoEnvBootstrap case a parameter list (the runtime
// command followed by KEY=VALUE pairs) is assembled but never returned or
// used - confirm whether it was meant to be appended to the command.
func getRuntimeCommand(runtime schsdk.JobRuntimeInfo, dataSetPath string, outputPath string, remoteBase string, ccInfo schmod.ComputingCenter) (string, []schsdk.KVPair) {
	envs := []schsdk.KVPair{
		{Key: schsdk.JobDataInEnv, Value: filepath.Join(remoteBase, dataSetPath)},
		{Key: schsdk.JobDataOutEnv, Value: filepath.Join(remoteBase, outputPath)},
	}
	envs = append(envs, runtime.Envs...)

	noEnvBoot, isNoEnv := ccInfo.Bootstrap.(*schsdk.NoEnvBootstrap)
	if !isNoEnv {
		// Direct bootstrap and any unknown bootstrap run the command as-is.
		return runtime.Command, envs
	}

	script := noEnvBoot.ScriptFileName
	params := []string{runtime.Command}
	params = append(params, lo.Map(envs, func(kv schsdk.KVPair, _ int) string {
		return fmt.Sprintf("%s=%s", kv.Key, kv.Value)
	})...)

	return script, envs
}
// getCCInfoAndStgInfo loads the computing center identified by targetCCID
// from the database and then queries the storage service, on behalf of
// userID, for the storage referenced by the center's CDSStorageID.
//
// It returns a pointer to a local copy of the computing center record and the
// storage response, or a wrapped error when any step fails. The storage
// client is taken from schglb.CloudreamStoragePool and released before
// returning.
func getCCInfoAndStgInfo(rtx jobmgr.JobStateRunContext, targetCCID schsdk.CCID, userID cdssdk.UserID) (*schmod.ComputingCenter, *cdsapi.StorageGetResp, error) {
	center, err := rtx.Mgr.DB.ComputingCenter().GetByID(rtx.Mgr.DB.SQLCtx(), targetCCID)
	if err != nil {
		return nil, nil, fmt.Errorf("getting computing center info: %w", err)
	}

	cli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return nil, nil, fmt.Errorf("new cds client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(cli)

	req := cdsapi.StorageGet{
		UserID:    userID,
		StorageID: center.CDSStorageID,
	}
	stgResp, err := cli.StorageGet(req)
	if err != nil {
		return nil, nil, fmt.Errorf("request to cds: %w", err)
	}

	return &center, stgResp, nil
}
// DataReturnJobExecuting is the job state that runs a data return job: its do
// method starts a storage create-package task for the target job's output
// path. It carries no fields.
type DataReturnJobExecuting struct {
}
// NewDataReturnJobExecuting returns a new, empty DataReturnJobExecuting state.
func NewDataReturnJobExecuting() *DataReturnJobExecuting {
	return new(DataReturnJobExecuting)
}
// Run executes the data return job via do and then moves the job to its final
// state: FailureComplete carrying the error when do fails, SuccessComplete
// otherwise.
func (s *DataReturnJobExecuting) Run(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) {
	if runErr := s.do(rtx, jo); runErr != nil {
		rtx.Mgr.ChangeState(jo, FailureComplete(runErr))
		return
	}

	rtx.Mgr.ChangeState(jo, SuccessComplete())
}
// Dump returns an empty DataReturnExecutingDump; this state has nothing to
// record. rtx and jo are not used.
func (s *DataReturnJobExecuting) Dump(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) jobmod.JobStateDump {
	snapshot := jobmod.DataReturnExecutingDump{}
	return &snapshot
}
// do packages the output of the target job: after a fixed 30 second delay it
// starts a storage create-package task for reJob.TargetJobOutputPath on the
// target job's computing center and stores the resulting package ID in
// reJob.DataReturnPackageID.
//
// A background goroutine waits for an *event.Cancel event and cancels ctx when
// the wait returns. Both the delay and the wait for the task status observe
// ctx, so cancelling the job interrupts them; previously ctx was cancelled
// but nothing read it.
//
// An error is returned when the job body is not a *job.DataReturnJob, the
// computing center lookup fails, the task cannot be started, ctx is
// cancelled, the status message has an unexpected type, or the status carries
// a non-empty error string.
func (s *DataReturnJobExecuting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
	reJob, ok := jo.Body.(*job.DataReturnJob)
	if !ok {
		return fmt.Errorf("unexpected job body type %T for data return job", jo.Body)
	}

	// TODO UserID
	userID := cdssdk.UserID(1)
	log := logger.WithType[DataReturnJobExecuting]("State").WithField("JobID", jo.JobID)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Listen for the cancel event. The deferred cancel above also ends this
	// goroutine's wait when do returns.
	go func() {
		event.WaitType[*event.Cancel](ctx, rtx.EventSet)
		cancel()
	}()

	ccInfo, err := rtx.Mgr.DB.ComputingCenter().GetByID(rtx.Mgr.DB.SQLCtx(), reJob.TargetJobCCID)
	if err != nil {
		return fmt.Errorf("getting computing center info: %w", err)
	}

	packageName := utils.MakeResourcePackageName(reJob.TargetJobID)
	logger.Info("TargetJobOutputPath: " + reJob.TargetJobOutputPath + ", and packageName: " + packageName)

	// Fixed delay before creating the package, abandoned early on cancel.
	// NOTE(review): the reason for the 30s delay is not visible here.
	delay := time.NewTimer(30 * time.Second)
	defer delay.Stop()
	select {
	case <-delay.C:
	case <-ctx.Done():
		return fmt.Errorf("waiting before creating package: %w", ctx.Err())
	}

	task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewStorageCreatePackage(
		userID, // TODO: user ID
		ccInfo.CDSStorageID,
		reJob.TargetJobOutputPath,
		reJob.Info.BucketID,
		packageName,
	), ccInfo)
	if err != nil {
		log.Error(err.Error())
		return err
	}

	fut := task.Receive()
	select {
	case <-ctx.Done():
		return fmt.Errorf("waiting for package creation: %w", ctx.Err())

	case status := <-fut.Chan():
		tskStatus, ok := status.Value.Status.(*exetsk.StorageCreatePackageStatus)
		if !ok || tskStatus == nil {
			return fmt.Errorf("unexpected package creation status type %T", status.Value.Status)
		}
		if tskStatus.Error != "" {
			return fmt.Errorf("creating package: %s", tskStatus.Error)
		}

		log.Infof("the outputs of job %v has been updated as a package %v(%v)", reJob.TargetJobID, packageName, tskStatus.PackageID)
		reJob.DataReturnPackageID = tskStatus.PackageID
		return nil
	}
}