317 lines
9.1 KiB
Go
317 lines
9.1 KiB
Go
package state
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"github.com/samber/lo"
|
||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
|
||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||
"gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor"
|
||
mgrmq "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/manager"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/executormgr"
|
||
jobTask "gitlink.org.cn/cloudream/scheduler/manager/internal/task"
|
||
"path/filepath"
|
||
|
||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||
pcmsdk "gitlink.org.cn/cloudream/common/sdks/pcm"
|
||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
|
||
exetsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
|
||
"gitlink.org.cn/cloudream/scheduler/common/utils"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr/event"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr/job"
|
||
)
|
||
|
||
type NormalJobExecuting struct {
|
||
lastStatus pcmsdk.TaskStatus
|
||
}
|
||
|
||
func NewNormalJobExecuting() *NormalJobExecuting {
|
||
return &NormalJobExecuting{
|
||
lastStatus: "Begin",
|
||
}
|
||
}
|
||
|
||
func (s *NormalJobExecuting) Run(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) {
|
||
err := s.do(rtx, jo)
|
||
if err != nil {
|
||
rtx.Mgr.ChangeState(jo, FailureComplete(err))
|
||
} else {
|
||
rtx.Mgr.ChangeState(jo, SuccessComplete())
|
||
}
|
||
}
|
||
|
||
func (s *NormalJobExecuting) Dump(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) jobmod.JobStateDump {
|
||
return &jobmod.NormalJobExecutingDump{
|
||
TaskStatus: s.lastStatus,
|
||
}
|
||
}
|
||
|
||
// do runs a NormalJob or an InstanceJob on its target computing center.
//
// Both job kinds first resolve the PCM image, the computing center and its
// resources. An InstanceJob then starts a ScheduleCreateECS (scale-out) task
// and hands control to listen. A NormalJob is submitted to PCM, and do blocks
// until that task reports success (or "Completed") or failure.
//
// It returns an error for an unsupported job body, for any failed lookup or
// task start, and when the submitted task fails.
func (s *NormalJobExecuting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
	// TODO UserID
	userID := cdssdk.UserID(1)

	var (
		runtime      *schsdk.JobRuntimeInfo
		jobFiles     *jobmod.JobFiles
		targetCCID   schsdk.CCID
		outputPath   string
		modelJobInfo *schsdk.ModelJobInfo
		packageID    cdssdk.PackageID
	)

	switch runningJob := jo.Body.(type) {
	case *job.NormalJob:
		runtime = &runningJob.Info.Runtime
		jobFiles = &runningJob.Files
		targetCCID = runningJob.TargetCCID
		outputPath = runningJob.OutputPath
		packageID = runningJob.Files.Dataset.PackageID
	case *job.InstanceJob:
		runtime = &runningJob.Info.Runtime
		jobFiles = &runningJob.Files
		targetCCID = runningJob.TargetCCID
		outputPath = runningJob.OutputPath
		modelJobInfo = &runningJob.Info.ModelJobInfo
		packageID = runningJob.Files.Dataset.PackageID
	default:
		// Without a match, runtime and jobFiles would stay nil and be
		// dereferenced by the lookups below.
		return fmt.Errorf("unsupported job body type %T", jo.Body)
	}

	pcmImgInfo, err := rtx.Mgr.DB.PCMImage().GetByImageIDAndCCID(rtx.Mgr.DB.SQLCtx(), jobFiles.Image.ImageID, targetCCID)
	if err != nil {
		return fmt.Errorf("getting pcm image info: %w", err)
	}

	ccInfo, err := rtx.Mgr.DB.ComputingCenter().GetByID(rtx.Mgr.DB.SQLCtx(), targetCCID)
	if err != nil {
		return fmt.Errorf("getting computing center info: %w", err)
	}

	ress, err := rtx.Mgr.DB.CCResource().GetByCCID(rtx.Mgr.DB.SQLCtx(), targetCCID)
	if err != nil {
		return fmt.Errorf("getting computing center resource: %w", err)
	}
	if len(ress) == 0 {
		return fmt.Errorf("no resource found at computing center %v", targetCCID)
	}

	// TODO: decide whether this is a model inference job and, if so, manage
	// scaling for it. Currently only an InstanceJob sets modelJobInfo.
	if modelJobInfo != nil {
		// Start the scale-out task.
		ecs := exetsk.NewScheduleCreateECS(
			userID,
			packageID,
			schsdk.ModelID(modelJobInfo.ModelID),
		)
		task, err := rtx.Mgr.ExecMgr.StartTask(ecs, ccInfo)
		if err != nil {
			return fmt.Errorf("starting schedule create ecs task: %w", err)
		}
		return s.listen(rtx, jo, task, ccInfo)
	}

	remoteBase, err := s.storageRemoteBase(userID, ccInfo)
	if err != nil {
		return err
	}

	// Check whether the computing center supports configuring environment
	// variables. If it does not, the bootstrap script is run instead, and the
	// command and envs are meant to be appended to it as parameters.
	var envs []schsdk.KVPair
	var params []string
	var cmd string

	// TODO temporary: this path should come from CDS. The "1" segment
	// presumably mirrors the hard-coded userID above; confirm.
	dataSetPath := filepath.Join("packages", "1", fmt.Sprintf("%v", jobFiles.Dataset.PackageID))
	envs = append(envs, schsdk.KVPair{Key: schsdk.JobDataInEnv, Value: filepath.Join(remoteBase, dataSetPath)})
	envs = append(envs, schsdk.KVPair{Key: schsdk.JobDataOutEnv, Value: filepath.Join(remoteBase, outputPath)})
	envs = append(envs, runtime.Envs...)

	switch boot := ccInfo.Bootstrap.(type) {
	case *schsdk.DirectBootstrap:
		cmd = runtime.Command
	case *schsdk.NoEnvBootstrap:
		// NOTE(review): params is built here but never submitted (see the
		// TODO at NewSubmitTask below), so with this bootstrap runtime.Command
		// does not reach the task at all; cmd is only the script name.
		cmd = boot.ScriptFileName
		params = append(params, runtime.Command)
		envMap := lo.Map(envs, func(env schsdk.KVPair, _ int) string {
			return fmt.Sprintf("%s=%s", env.Key, env.Value)
		})
		params = append(params, envMap...)
	default:
		cmd = runtime.Command
	}

	task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewSubmitTask(
		ccInfo.PCMParticipantID,
		pcmImgInfo.PCMImageID,
		// TODO: algorithm for choosing the resource.
		ress[0].PCMResourceID,
		cmd,
		envs,
		// params, TODO: params should be a string array, not a KV array.
		[]schsdk.KVPair{},
	), ccInfo)
	if err != nil {
		return fmt.Errorf("starting submit task: %w", err)
	}

	return s.waitSubmitTask(jo, task)
}

// storageRemoteBase asks CDS for the storage bound to the computing center
// and returns its RemoteBase. The pooled storage client is released before
// returning, so it is not held while the submitted task runs.
func (s *NormalJobExecuting) storageRemoteBase(userID cdssdk.UserID, ccInfo schmod.ComputingCenter) (string, error) {
	stgCli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return "", fmt.Errorf("new cds client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(stgCli)

	getStg, err := stgCli.StorageGet(cdssdk.StorageGet{
		UserID:    userID,
		StorageID: ccInfo.CDSStorageID,
	})
	if err != nil {
		return "", fmt.Errorf("request to cds: %w", err)
	}
	return getStg.RemoteBase, nil
}

// waitSubmitTask blocks until the PCM submit task reports a terminal status.
// Each status change is logged and stored in s.lastStatus for Dump. Success
// and "Completed" return nil; Failed and any non-SubmitTaskStatus report
// return an error.
func (s *NormalJobExecuting) waitSubmitTask(jo *jobmgr.Job, task *jobTask.JobTask[mgrmq.ExecutorTaskStatus]) error {
	log := logger.WithType[NormalJobExecuting]("State").WithField("JobID", jo.JobID)

	for {
		taskFut := task.Receive()
		msg := <-taskFut.Chan()

		tskStatus, ok := msg.Value.Status.(*exetsk.SubmitTaskStatus)
		if !ok {
			// listen handles a bare error in this field too, so report it
			// rather than panicking on the type assertion.
			if statusErr, isErr := msg.Value.Status.(error); isErr {
				return fmt.Errorf("submit task: %w", statusErr)
			}
			return fmt.Errorf("unexpected submit task status type %T", msg.Value.Status)
		}

		if tskStatus.Status != s.lastStatus {
			log.Infof("task %s -> %s", s.lastStatus, tskStatus.Status)
		}
		s.lastStatus = tskStatus.Status

		switch tskStatus.Status {
		case pcmsdk.TaskStatusSuccess, "Completed":
			return nil
		case pcmsdk.TaskStatusFailed:
			return fmt.Errorf("task failed")
		}
	}
}
|
||
|
||
// listen services a running scale-out task for an InstanceJob.
//
// It loops forever, multiplexing two sources:
//   - Update events from rtx.EventSet, which are forwarded to the executor
//     as task operations. The event's Result is always set, even on failure,
//     so the sender is never left waiting.
//   - Status messages from the task. A ScheduleCreateECSStatus registers the
//     new node via jobmgr.SetNodeData; anything else is logged.
//
// It returns only when forwarding an Update fails.
func (s *NormalJobExecuting) listen(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job, task *jobTask.JobTask[mgrmq.ExecutorTaskStatus], ccInfo schmod.ComputingCenter) error {
	log := logger.WithType[NormalJobExecuting]("State").WithField("TaskID", task.ID())

	waitFut := event.BeginWaitType[*event.Update](rtx.EventSet)
	taskFut := task.Receive()

	for {
		select {
		case v1 := <-waitFut.Chan():
			// Apply an update operation to the task.
			evt := v1.Value.(*event.Update)
			if err := s.operateTask(task, ccInfo, evt); err != nil {
				return err
			}

			// Keep waiting for further updates.
			waitFut = event.BeginWaitType[*event.Update](rtx.EventSet)

		case msg := <-taskFut.Chan():
			switch v2 := msg.Value.Status.(type) {
			case *exetsk.ScheduleCreateECSStatus:
				// Scale-out task: put the resulting node into the pool.
				node := schsdk.NodeInfo{
					InstanceID: jo.JobID,
					Address:    schsdk.Address(v2.Address),
				}
				jobmgr.SetNodeData(schsdk.JobID(jo.JobSetID), v2.ModelID, node)
				log.Infof("node expansion: %v", v2.Address)
			case error:
				log.Error(fmt.Sprintf("received error: %s", v2.Error()))
			default:
				log.Error(fmt.Sprintf("received unexpected task status type %T", v2))
			}

			// Keep receiving task status messages.
			taskFut = task.Receive()
		}
	}
}

// operateTask forwards evt.Command to the executor that runs task and always
// completes evt.Result, with the error if any step failed. It returns a
// non-nil error when the executor could not be reached or the operation
// failed.
func (s *NormalJobExecuting) operateTask(task *jobTask.JobTask[mgrmq.ExecutorTaskStatus], ccInfo schmod.ComputingCenter, evt *event.Update) error {
	// NOTE(review): the acquired executor client is never returned to
	// ExecutorPool here; confirm whether the pool expects a release.
	client, err := executormgr.ExecutorPool.AcquireByUrl(ccInfo.ExecutorURL)
	if err != nil {
		err = fmt.Errorf("getting executor client: %w", err)
		evt.Result.SetValue(event.UpdateResult{Err: err})
		return err
	}

	operateResp, err := client.OperateTask(executor.NewTaskOperateInfo(task.ID(), evt.Command))
	if err != nil {
		err = fmt.Errorf("operate task: %w", err)
		evt.Result.SetValue(event.UpdateResult{Err: err})
		return err
	}

	evt.Result.SetValue(event.UpdateResult{
		Err: operateResp.Err,
	})
	if operateResp.Err != nil {
		return fmt.Errorf("operate task: %w", operateResp.Err)
	}
	return nil
}
|
||
|
||
// DataReturnJobExecuting is the job state in which a target job's output
// directory is turned into a storage package. It carries no fields.
type DataReturnJobExecuting struct{}
|
||
|
||
func NewDataReturnJobExecuting() *DataReturnJobExecuting {
|
||
return &DataReturnJobExecuting{}
|
||
}
|
||
|
||
func (s *DataReturnJobExecuting) Run(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) {
|
||
err := s.do(rtx, jo)
|
||
if err != nil {
|
||
rtx.Mgr.ChangeState(jo, FailureComplete(err))
|
||
} else {
|
||
rtx.Mgr.ChangeState(jo, SuccessComplete())
|
||
}
|
||
}
|
||
|
||
func (s *DataReturnJobExecuting) Dump(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) jobmod.JobStateDump {
|
||
return &jobmod.DataReturnExecutingDump{}
|
||
}
|
||
|
||
func (s *DataReturnJobExecuting) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
|
||
reJob := jo.Body.(*job.DataReturnJob)
|
||
userID := cdssdk.UserID(1)
|
||
|
||
log := logger.WithType[NormalJobExecuting]("State").WithField("JobID", jo.JobID)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
// 监听取消事件
|
||
go func() {
|
||
event.WaitType[*event.Cancel](ctx, rtx.EventSet)
|
||
cancel()
|
||
}()
|
||
|
||
ccInfo, err := rtx.Mgr.DB.ComputingCenter().GetByID(rtx.Mgr.DB.SQLCtx(), reJob.TargetJobCCID)
|
||
if err != nil {
|
||
return fmt.Errorf("getting computing center info: %w", err)
|
||
}
|
||
|
||
packageName := utils.MakeResourcePackageName(reJob.TargetJobID)
|
||
task, err := rtx.Mgr.ExecMgr.StartTask(exetsk.NewStorageCreatePackage(
|
||
userID, // TOOD 用户ID
|
||
ccInfo.CDSStorageID,
|
||
reJob.TargetJobOutputPath,
|
||
reJob.Info.BucketID,
|
||
packageName,
|
||
), ccInfo)
|
||
if err != nil {
|
||
log.Error(err.Error())
|
||
return err
|
||
}
|
||
|
||
fut := task.Receive()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
status := <-fut.Chan()
|
||
tskStatus := status.Value.Status.(*exetsk.StorageCreatePackageStatus)
|
||
if tskStatus.Error != "" {
|
||
return fmt.Errorf("creating package: %s", tskStatus.Error)
|
||
}
|
||
|
||
log.Infof("the outputs of job %v has been updated as a package %v(%v)", reJob.TargetJobID, packageName, tskStatus.PackageID)
|
||
|
||
reJob.DataReturnPackageID = tskStatus.PackageID
|
||
return nil
|
||
}
|