125 lines
3.7 KiB
Go
125 lines
3.7 KiB
Go
package state
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
|
||
"gitlink.org.cn/cloudream/common/pkgs/future"
|
||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
|
||
"gitlink.org.cn/cloudream/common/sdks/storage/cdsapi"
|
||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr/event"
|
||
"gitlink.org.cn/cloudream/scheduler/manager/internal/jobmgr/job"
|
||
)
|
||
|
||
type MultiInstanceUpdate struct {
|
||
originalJob jobmod.JobDump
|
||
}
|
||
|
||
func NewMultiInstanceUpdate(originalJob jobmod.JobDump) *MultiInstanceUpdate {
|
||
return &MultiInstanceUpdate{
|
||
originalJob: originalJob,
|
||
}
|
||
}
|
||
|
||
func (s *MultiInstanceUpdate) Run(rtx jobmgr.JobStateRunContext, job *jobmgr.Job) {
|
||
err := s.do(rtx, job)
|
||
if err != nil {
|
||
logger.Error("update multi instance failed: %s", err)
|
||
return
|
||
}
|
||
}
|
||
|
||
func (s *MultiInstanceUpdate) do(rtx jobmgr.JobStateRunContext, jo *jobmgr.Job) error {
|
||
updateJob := jo.Body.(*job.UpdateMultiInstanceJob)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
// 监听取消事件
|
||
go func() {
|
||
event.WaitType[*event.Cancel](ctx, rtx.EventSet)
|
||
cancel()
|
||
}()
|
||
|
||
var fullPath string
|
||
instanceJob := jo.Body.(*job.UpdateMultiInstanceJob)
|
||
if instanceJob.Info.UpdateType == schsdk.FineTuning {
|
||
var dtrJob *job.DataReturnJob
|
||
// 等待回源任务完成
|
||
if rt, ok := updateJob.Info.Files.Dataset.(*schsdk.DataReturnJobFileInfo); ok {
|
||
evt, ok := event.WaitTypeAnd[*event.JobCompleted](ctx, rtx.EventSet, func(val *event.JobCompleted) bool {
|
||
return val.Job.GetInfo().GetLocalJobID() == rt.DataReturnLocalJobID
|
||
})
|
||
if !ok {
|
||
return jobmgr.ErrJobCancelled
|
||
}
|
||
if evt.Err != nil {
|
||
return fmt.Errorf("depended job %s was failed", evt.Job.JobID)
|
||
}
|
||
dtrJob, ok = evt.Job.Body.(*job.DataReturnJob)
|
||
if !ok {
|
||
return fmt.Errorf("job %s is not a DataReturn job(which is %T)", evt.Job.JobID, evt.Job)
|
||
}
|
||
}
|
||
|
||
stgCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return fmt.Errorf("new cloudream storage client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(stgCli)
|
||
|
||
ccInfo, err := rtx.Mgr.DB.ComputingCenter().GetByID(rtx.Mgr.DB.SQLCtx(), dtrJob.TargetJobCCID)
|
||
if err != nil {
|
||
return fmt.Errorf("getting computing center info: %w", err)
|
||
}
|
||
|
||
userID := cdssdk.UserID(1)
|
||
getStg, err := stgCli.StorageGet(cdsapi.StorageGet{
|
||
UserID: userID,
|
||
StorageID: ccInfo.CDSStorageID,
|
||
})
|
||
|
||
loadPackageResp, err := stgCli.StorageLoadPackage(cdsapi.StorageLoadPackageReq{
|
||
UserID: userID,
|
||
PackageID: dtrJob.DataReturnPackageID,
|
||
StorageID: getStg.StorageID,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("loading package: %w", err)
|
||
}
|
||
logger.Info("load pacakge path: " + loadPackageResp.FullPath)
|
||
fullPath = loadPackageResp.FullPath
|
||
}
|
||
|
||
// 发送事件,更新各个instance
|
||
updateJob.Info.Runtime.Envs = append(updateJob.Info.Runtime.Envs, schsdk.KVPair{Key: schsdk.FinetuningOutEnv, Value: fullPath})
|
||
updateInfo := event.InstanceUpdateInfo{
|
||
Info: updateJob.Info,
|
||
}
|
||
fut := future.NewSetValue[event.OperateInstanceResult]()
|
||
rtx.Mgr.PostEvent(s.originalJob.JobID, event.NewInstanceOperate(&updateInfo, fut))
|
||
|
||
result, err := fut.Wait(context.TODO())
|
||
|
||
if err != nil {
|
||
return err
|
||
}
|
||
println(result.JobID)
|
||
|
||
if result.Err != nil {
|
||
return fmt.Errorf("update instance failed: %s", result.OperateResult)
|
||
}
|
||
|
||
logger.Info("update instance success!")
|
||
return nil
|
||
}
|
||
|
||
func (s *MultiInstanceUpdate) Dump(ctx jobmgr.JobStateRunContext, job *jobmgr.Job) jobmod.JobStateDump {
|
||
return &jobmod.MultiInstanceUpdateDump{}
|
||
}
|