// JCC-CSScheduler/schedulerMiddleware/internal/services/jobset.go

package services
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strings"
"time"

"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/algorithm"
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/dataset"
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/model"
"gitlink.org.cn/cloudream/common/pkgs/logger"
sch "gitlink.org.cn/cloudream/common/sdks/pcmscheduler"
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/common/sdks/storage/cdsapi"
uploadersdk "gitlink.org.cn/cloudream/common/sdks/uploader"
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/event"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job/state"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job/state2"
)
type JobSetService struct {
*Service
}
func (svc *Service) JobSetSvc() *JobSetService {
return &JobSetService{Service: svc}
}
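// PreScheduler pre-schedules the given job set against all computing centers
// recorded in the database and returns the resulting schedule scheme together
// with the file upload scheme.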
func (svc *JobSetService) PreScheduler(jobSet schsdk.JobSetInfo) (*jobmod.JobSetPreScheduleScheme, *schsdk.JobSetFilesUploadScheme, error) {
ccs, err := svc.db.ComputingCenter().GetAll(svc.db.DefCtx())
if err != nil {
logger.Warnf("getting all computing center: %s", err.Error())
return nil, nil, err
}
schScheme, uploadScheme, err := svc.preScheduler.ScheduleJobSet(&jobSet, ccs)
if err != nil {
return nil, nil, fmt.Errorf("pre scheduling: %w", err)
}
return schScheme, uploadScheme, nil
}
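// Upload submits a job set containing a single data-upload job; the upload
// parameters are carried in the job's initial DataUpload state.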
func (svc *JobSetService) Upload(userID cdssdk.UserID, params sch.UploadParams, blockChainToken string) (*schsdk.JobSetID, error) {
logger.Debugf("uploading job")
// Query the clusters maintained in the database
//ccs, err := svc.db.ComputingCenter().GetAll(svc.db.DefCtx())
//if err != nil {
// logger.Warnf("getting all computing center: %s", err.Error())
// return nil, nil, err
//}
// Get the mapping between clusters and storages
//clusterMapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
//if err != nil {
// return nil, nil, fmt.Errorf("query cluster mapping error: %w", err)
//}
//
//var storages []cdssdk.StorageID
//switch uploadPriority := params.UploadPriority.(type) {
//case *sch.Preferences:
// // Perform pre-scheduling
// clusterID, err := svc.preScheduler.ScheduleJob(uploadPriority.ResourcePriorities, clusterMapping)
// if err != nil {
// return nil, nil, fmt.Errorf("pre scheduling: %w", err)
// }
//
// storageID, ok := clusterMapping[*clusterID]
// if !ok {
// return nil, nil, fmt.Errorf("cluster %d not found", clusterID)
// }
//
// storages = append(storages, storageID)
//case *sch.SpecifyCluster:
// // Use the specified clusters
// for _, clusterID := range uploadPriority.Clusters {
// storageID, ok := clusterMapping[clusterID]
// if !ok {
// logger.Warnf("cluster %d not found", clusterID)
// continue
// }
// storages = append(storages, storageID)
// }
//}
//
//if len(storages) == 0 {
// return nil, nil, errors.New("no storage is available")
//}
var jobs []jobmgr.SubmittingJob
jo := job.NewNormalJob(schsdk.NormalJobInfo{})
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
InitState: state2.NewDataUpload(userID, params.UploadInfo, params.DataType, blockChainToken),
})
jobSetID := svc.jobMgr.SubmitJobSet(jobs)
return &jobSetID, nil
}
// Submit submits a job set, creating the appropriate job and initial state for each job info type.
func (svc *JobSetService) Submit(userID cdssdk.UserID, jobSet schsdk.JobSetInfo, schScheme *jobmod.JobSetPreScheduleScheme, token string) (*schsdk.JobSetID, error) {
logger.Debugf("submitting job")
var jobs []jobmgr.SubmittingJob
for _, jobInfo := range jobSet.Jobs {
switch info := jobInfo.(type) {
case *schsdk.PCMJobInfo:
jo := job.NewPCMJob(*info)
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
//InitState: state.NewPreSchuduling(preSch),
InitState: state2.NewPCMJobCreate(userID, info, token),
})
case *schsdk.NormalJobInfo:
jo := job.NewNormalJob(*info)
jo.SubType = schsdk.JobTypeNormal
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
//InitState: state.NewPreSchuduling(preSch),
})
case *schsdk.DataReturnJobInfo:
jo := job.NewDataReturnJob(*info)
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
InitState: state.NewWaitTargetComplete(),
})
case *schsdk.MultiInstanceJobInfo:
preSch, ok := schScheme.JobSchemes[info.LocalJobID]
if !ok {
return nil, fmt.Errorf("pre schedule scheme for job %s is not found", info.LocalJobID)
}
jo := job.NewMultiInstanceJob(*info, preSch)
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
InitState: state.NewMultiInstanceInit(),
})
case *schsdk.UpdateMultiInstanceJobInfo:
modelJob := job.NewUpdateMultiInstanceJob(*info)
instanceJobSets := svc.jobMgr.DumpJobSet(modelJob.Info.MultiInstanceJobSetID)
if len(instanceJobSets) == 0 {
return nil, fmt.Errorf("job set %s is not found", modelJob.Info.MultiInstanceJobSetID)
}
// Find the multi-instance job itself
var multiInstanceJobDump jobmod.JobDump
for i := 0; i < len(instanceJobSets); i++ {
jobDump := instanceJobSets[i]
if _, ok := jobDump.Body.(*jobmod.MultiInstanceJobDump); ok {
multiInstanceJobDump = jobDump
break
}
}
jobs = append(jobs, jobmgr.SubmittingJob{
Body: modelJob,
InitState: state.NewMultiInstanceUpdate(multiInstanceJobDump),
})
case *schsdk.DataPreprocessJobInfo:
// The subsequent scheduling flow is the same as for a NormalJob
normalJobInfo := &schsdk.NormalJobInfo{
Type: schsdk.JobTypeNormal,
JobInfoBase: info.JobInfoBase,
Files: info.Files,
Runtime: info.Runtime,
Services: info.Services,
Resources: info.Resources,
}
jo := job.NewNormalJob(*normalJobInfo)
jo.SubType = schsdk.JobTypeDataPreprocess
preSch, ok := schScheme.JobSchemes[info.LocalJobID]
if !ok {
return nil, fmt.Errorf("pre schedule scheme for job %s is not found", info.LocalJobID)
}
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
InitState: state.NewPreSchuduling(preSch),
})
case *schsdk.FinetuningJobInfo:
// The subsequent scheduling flow is the same as for a NormalJob
normalJobInfo := &schsdk.NormalJobInfo{
Type: schsdk.JobTypeNormal,
Files: info.Files,
JobInfoBase: info.JobInfoBase,
Runtime: info.Runtime,
Services: info.Services,
Resources: info.Resources,
ModelJobInfo: info.ModelJobInfo,
}
jo := job.NewNormalJob(*normalJobInfo)
jo.SubType = schsdk.JobTypeFinetuning
preSch, ok := schScheme.JobSchemes[info.LocalJobID]
if !ok {
return nil, fmt.Errorf("pre schedule scheme for job %s is not found", info.LocalJobID)
}
jobs = append(jobs, jobmgr.SubmittingJob{
Body: jo,
InitState: state.NewPreSchuduling(preSch),
})
}
}
jobSetID := svc.jobMgr.SubmitJobSet(jobs)
return &jobSetID, nil
}
// LocalFileUploaded notifies the job set that a local file upload has finished.
// An empty errMsg is treated as success and no error is attached to the event.
func (svc *JobSetService) LocalFileUploaded(jobSetID schsdk.JobSetID, localPath string, errMsg string, packageID cdssdk.PackageID, objectIDs []cdssdk.ObjectID) {
var err error
if errMsg != "" {
err = errors.New(errMsg)
}
svc.jobMgr.BroadcastEvent(jobSetID, event.NewLocalFileUploaded(localPath, err, packageID, objectIDs))
}
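// CreateFolder records a folder entry for the given package in the upload-data tables.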
func (svc *JobSetService) CreateFolder(packageID cdssdk.PackageID, path string) error {
err := svc.JobSetSvc().db.UploadData().InsertFolder(svc.db.DefCtx(), packageID, path)
if err != nil {
return err
}
return nil
}
// DeleteFile deletes files or folders by their object IDs.
func (svc *JobSetService) DeleteFile(userID cdssdk.UserID, objectIDs []cdssdk.ObjectID) error {
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return fmt.Errorf("new scheduler client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
err = cdsCli.Object().Delete(cdsapi.ObjectDelete{
ObjectIDs: objectIDs,
UserID: userID,
})
if err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
return nil
}
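// DeleteFolder deletes every object under the given path prefix from storage
// and then removes the folder record from the database.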
func (svc *JobSetService) DeleteFolder(userID cdssdk.UserID, packageID cdssdk.PackageID, path string) error {
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return fmt.Errorf("new scheduler client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
list, err := cdsCli.Object().ListByPath(cdsapi.ObjectListByPath{
UserID: userID,
PackageID: packageID,
Path: path,
IsPrefix: true,
})
if err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
if len(list.Objects) > 0 {
var objectIDs []cdssdk.ObjectID
for _, obj := range list.Objects {
objectIDs = append(objectIDs, obj.ObjectID)
}
err = cdsCli.Object().Delete(cdsapi.ObjectDelete{
ObjectIDs: objectIDs,
UserID: userID,
})
if err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
}
err = svc.JobSetSvc().db.UploadData().DeleteFolder(svc.db.DefCtx(), packageID, path)
if err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
return nil
}
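// QueryUploaded lists uploaded data. A PackageID of -1 lists the user's
// packages (the root directory); otherwise the objects under the given path
// are listed, merged with folder records, sorted, paginated, and returned
// inside the owning package together with the total page and object counts.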
func (svc *JobSetService) QueryUploaded(queryParams sch.QueryData) ([]uploadersdk.Package, int, int, error) {
// Query the root directory
if queryParams.PackageID == -1 {
packages, err := svc.JobSetSvc().db.UploadData().QueryPackage(svc.db.DefCtx(), queryParams)
if err != nil {
return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
}
return packages, 0, 0, nil
}
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return nil, 0, 0, fmt.Errorf("new scheduler client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
queryListReq := cdsapi.ObjectListByPath{
UserID: queryParams.UserID,
PackageID: queryParams.PackageID,
Path: queryParams.Path,
IsPrefix: true,
}
objList, err := cdsCli.Object().ListByPath(queryListReq)
if err != nil {
return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
}
folderMap := make(map[string]cdssdk.Object)
var modifyObjs []cdssdk.Object
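// Objects nested more than one level below the queried path are collapsed
// into synthetic folder entries (ObjectID -1), so only the immediate children
// of the path are returned.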
for _, obj := range objList.Objects {
// Strip the queryParams.Path prefix from obj.Path
obj.Path = strings.TrimPrefix(obj.Path, queryParams.Path)
pathArr := strings.Split(obj.Path, "/")
if len(pathArr) > 2 {
splitPath := "/" + pathArr[1]
folderMap[splitPath] = cdssdk.Object{
ObjectID: -1,
PackageID: obj.PackageID,
Path: pathArr[1],
Size: 0,
CreateTime: obj.CreateTime,
}
continue
}
modifyObjs = append(modifyObjs, obj)
}
folders, err := svc.db.UploadData().QueryFolder(svc.db.DefCtx(), queryParams)
if err != nil {
return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
}
for _, folder := range folders {
folder.Path = strings.TrimPrefix(folder.Path, queryParams.Path)
if folder.Path == "" {
continue
}
folderMap[folder.Path] = cdssdk.Object{
ObjectID: -1,
PackageID: folder.PackageID,
Path: folder.Path,
Size: 0,
CreateTime: folder.CreateTime,
}
}
// Append the synthesized folder entries to the object list
for _, obj := range folderMap {
modifyObjs = append(modifyObjs, obj)
}
objList.Objects = modifyObjs
// Sort by the orderBy field
sort.Slice(objList.Objects, func(i, j int) bool {
if queryParams.OrderBy == sch.OrderByName {
return objList.Objects[i].Path < objList.Objects[j].Path
} else if queryParams.OrderBy == sch.OrderBySize {
return objList.Objects[i].Size < objList.Objects[j].Size
} else if queryParams.OrderBy == sch.OrderByTime {
return objList.Objects[i].CreateTime.Unix() < objList.Objects[j].CreateTime.Unix()
}
return false
})
totalNum := len(objList.Objects)
totalPages := 1 // default when no paging is requested
// Paginate the result
if queryParams.PageSize > 0 {
start := (queryParams.CurrentPage - 1) * queryParams.PageSize
end := start + queryParams.PageSize
if start > totalNum {
return nil, 0, 0, nil
}
if end > totalNum {
end = totalNum
}
objList.Objects = objList.Objects[start:end]
// Round up so that a partially filled last page is counted
totalPages = (totalNum + queryParams.PageSize - 1) / queryParams.PageSize
}
var datas []uploadersdk.Package
data, err := svc.db.UploadData().QueryPackageByID(svc.db.DefCtx(), queryParams.PackageID)
if err != nil {
return nil, 0, 0, err
}
pkg := uploadersdk.Package{
PackageID: data.PackageID,
BucketID: data.BucketID,
DataType: data.DataType,
PackageName: data.PackageName,
JsonData: data.JsonData,
BindingID: data.BindingID,
UserID: data.UserID,
CreateTime: data.CreateTime,
Objects: objList.Objects,
UploadedCluster: data.UploadedCluster,
}
datas = append(datas, pkg)
return datas, totalPages, totalNum, nil
}
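// DataBinding registers uploaded packages as a dataset, algorithm (code),
// image, or model binding. Dataset, code, and model bindings are also
// registered with the hub client (currently hard-coded to "ModelArts") before
// the binding record and the package associations are persisted.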
func (svc *JobSetService) DataBinding(id uploadersdk.DataID, userID cdssdk.UserID, info sch.DataBinding) error {
var bindingData uploadersdk.Binding
var packageIDs []cdssdk.PackageID
isCode := false
switch bd := info.(type) {
case *sch.DatasetBinding:
packages, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), bd.PackageIDs, []int64{-2})
if err != nil {
return err
}
if len(packages) == 0 {
return fmt.Errorf("no package found")
}
filePath := sch.Split + sch.DATASET + sch.Split + packages[0].PackageName
ds := &dataset.Dataset{Name: bd.Name, Description: bd.Description, FilePath: filePath, Category: dataset.CommonValue(bd.Category)}
resp, err := svc.hubClient.BindDataset("ModelArts", ds)
if err != nil {
return err
}
jsonData, err := json.Marshal(resp.Data)
if err != nil {
return err
}
content, err := json.Marshal(bd)
if err != nil {
return err
}
bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content), string(jsonData))
packageIDs = bd.PackageIDs
case *sch.CodeBinding:
packages, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), []cdssdk.PackageID{bd.PackageID}, []int64{-2})
if err != nil {
return err
}
if len(packages) == 0 {
return fmt.Errorf("no package found")
}
filePath := sch.Split + sch.CODE + sch.Split + packages[0].PackageName
code := &algorithm.Algorithm{Name: bd.Name, Description: bd.Description, Engine: "TensorFlow", CodeDir: filePath, BootFile: filePath + sch.Split + bd.FilePath, Branch: "main"}
resp, err := svc.hubClient.BindAlgorithm("ModelArts", code)
if err != nil {
return err
}
jsonData, err := json.Marshal(resp.Data)
if err != nil {
return err
}
isCode = true
content, err := json.Marshal(bd)
if err != nil {
return err
}
bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content), string(jsonData))
packageIDs = []cdssdk.PackageID{bd.PackageID}
case *sch.ImageBinding:
content, err := json.Marshal(bd)
if err != nil {
return err
}
bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content), "")
packageIDs = bd.PackageIDs
case *sch.ModelBinding:
packages, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), bd.PackageIDs, []int64{-2})
if err != nil {
return err
}
if len(packages) == 0 {
return fmt.Errorf("no package found")
}
filePath := sch.Split + sch.MODEL + sch.Split + packages[0].PackageName
md := &model.Model{Name: bd.Name, Description: bd.Description, Type: bd.Category, Version: bd.Version, Engine: model.CommonValue(bd.Env), FilePath: filePath}
resp, err := svc.hubClient.BindModel("ModelArts", md)
if err != nil {
return err
}
jsonData, err := json.Marshal(resp.Data)
if err != nil {
return err
}
content, err := json.Marshal(bd)
if err != nil {
return err
}
bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content), string(jsonData))
packageIDs = bd.PackageIDs
}
if bindingData.AccessLevel == "" {
bindingData.AccessLevel = sch.PrivateAccess
}
bindingData.CreateTime = time.Now()
bindingID, err := svc.db.UploadData().InsertOrUpdateBinding(svc.db.DefCtx(), bindingData)
if err != nil {
return err
}
for _, pkgID := range packageIDs {
err = svc.db.UploadData().UpdatePackage(svc.db.DefCtx(), pkgID, "", *bindingID)
if err != nil {
return err
}
// Algorithm (code) bindings need version management
if isCode {
pkg, err := svc.db.UploadData().QueryPackageByID(svc.db.DefCtx(), pkgID)
if err != nil {
return err
}
err = svc.db.UploadData().InsertPackageVersion(svc.db.DefCtx(), pkgID, pkgID, pkg.PackageName, 1)
if err != nil {
return err
}
}
}
return nil
}
func getBindingData(id uploadersdk.DataID, userID cdssdk.UserID, dataType string, name string, content string, jsonData string) uploadersdk.Binding {
bindingData := uploadersdk.Binding{
ID: id,
Name: name,
DataType: dataType,
UserID: userID,
Content: content,
JsonData: jsonData,
}
return bindingData
}
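// RemoveBinding detaches the given packages from their bindings and deletes
// the given binding records.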
func (svc *JobSetService) RemoveBinding(packageIDs []cdssdk.PackageID, bindingIDs []int64) error {
if len(packageIDs) > 0 {
for _, id := range packageIDs {
err := svc.db.UploadData().UpdatePackage(svc.db.DefCtx(), id, "", uploadersdk.DataID(-1))
if err != nil {
return err
}
}
}
if len(bindingIDs) > 0 {
err := svc.db.UploadData().DeleteBindingsByID(svc.db.DefCtx(), bindingIDs)
if err != nil {
return err
}
}
return nil
}
func (svc *JobSetService) DeleteBinding(IDs []int64) error {
err := svc.db.UploadData().DeleteBindingsByID(svc.db.DefCtx(), IDs)
if err != nil {
return err
}
return nil
}
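// QueryBinding returns binding details of the given data type, filtered by
// access level: the user's private bindings, public bindings, or bindings the
// user has applied for.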
func (svc *JobSetService) QueryBinding(dataType string, param sch.QueryBindingDataParam) ([]uploadersdk.BindingDetail, error) {
switch p := param.(type) {
case *sch.PrivateLevel:
return svc.queryPrivateBinding(p.UserID, uploadersdk.DataID(p.BindingID), dataType)
case *sch.PublicLevel:
var details []uploadersdk.BindingDetail
datas, err := svc.db.UploadData().GetPublicBindings(svc.db.DefCtx(), p.Type, dataType, p.UserID)
if err != nil {
return nil, err
}
for _, data := range datas {
var info sch.DataBinding
binding := uploadersdk.Binding{
DataType: dataType,
Content: data.Content,
}
info, err = unmarshalBinding(binding)
if err != nil {
return nil, err
}
bindingDetail := uploadersdk.BindingDetail{
ID: data.ID,
UserID: data.UserID,
UserName: data.UserName,
SSOId: data.SSOId,
Name: data.Name,
Info: info,
AccessLevel: data.AccessLevel,
}
details = append(details, bindingDetail)
}
return details, nil
case *sch.ApplyLevel:
var details []uploadersdk.BindingDetail
datas, err := svc.db.UploadData().GetApplyBindings(svc.db.DefCtx(), p.UserID, p.Type, dataType)
if err != nil {
return nil, err
}
for _, data := range datas {
var info sch.DataBinding
// Only entries in the approved status expose their details
if data.Status == sch.ApprovedStatus {
binding := uploadersdk.Binding{
DataType: dataType,
Content: data.Content,
}
info, err = unmarshalBinding(binding)
if err != nil {
return nil, err
}
}
bindingDetail := uploadersdk.BindingDetail{
ID: data.ID,
UserID: data.UserID,
UserName: data.UserName,
SSOId: data.SSOId,
Name: data.Name,
Info: info,
Status: data.Status,
AccessLevel: data.AccessLevel,
}
details = append(details, bindingDetail)
}
return details, nil
}
return nil, fmt.Errorf("unknown query binding data type")
}
func (svc *JobSetService) queryPrivateBinding(userID cdssdk.UserID, bindingID uploadersdk.DataID, dataType string) ([]uploadersdk.BindingDetail, error) {
var details []uploadersdk.BindingDetail
if bindingID == -1 {
datas, err := svc.db.UploadData().GetPrivateBindings(svc.db.DefCtx(), userID, dataType)
if err != nil {
return nil, err
}
for _, data := range datas {
info, err := unmarshalBinding(data)
if err != nil {
return nil, err
}
bindingDetail := uploadersdk.BindingDetail{
ID: data.ID,
UserID: data.UserID,
Info: info,
AccessLevel: data.AccessLevel,
}
details = append(details, bindingDetail)
}
return details, nil
}
data, err := svc.db.UploadData().GetBindingsByID(svc.db.DefCtx(), bindingID)
if err != nil {
return nil, err
}
info, err := unmarshalBinding(*data)
if err != nil {
return nil, err
}
packages, err := svc.db.UploadData().QueryPackageByBindingID(svc.db.DefCtx(), bindingID)
if err != nil {
return nil, err
}
detail := uploadersdk.BindingDetail{
ID: bindingID,
UserID: data.UserID,
Packages: packages,
Info: info,
AccessLevel: data.AccessLevel,
}
details = append(details, detail)
return details, nil
}
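// unmarshalBinding decodes the stored JSON content into the concrete
// DataBinding type indicated by DataType. An unrecognized DataType yields a
// nil DataBinding and no error.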
func unmarshalBinding(data uploadersdk.Binding) (sch.DataBinding, error) {
var info sch.DataBinding
switch data.DataType {
case sch.DATASET:
var content sch.DatasetBinding
err := json.Unmarshal([]byte(data.Content), &content)
if err != nil {
return nil, err
}
info = &content
case sch.CODE:
var content sch.CodeBinding
err := json.Unmarshal([]byte(data.Content), &content)
if err != nil {
return nil, err
}
info = &content
case sch.IMAGE:
var content sch.ImageBinding
err := json.Unmarshal([]byte(data.Content), &content)
if err != nil {
return nil, err
}
info = &content
case sch.MODEL:
var content sch.ModelBinding
err := json.Unmarshal([]byte(data.Content), &content)
if err != nil {
return nil, err
}
info = &content
}
return info, nil
}
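// CreatePackage creates a package in the user's bucket for the given data
// type, pre-schedules its target clusters according to uploadPriority, and
// persists the package record.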
func (svc *JobSetService) CreatePackage(userID cdssdk.UserID, name string, dataType string, uploadPriority sch.UploadPriority) error {
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return fmt.Errorf("new cds client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
bucket, err := svc.db.Access().GetBucketByUserID(svc.db.DefCtx(), userID, dataType)
if err != nil {
return fmt.Errorf("failed to get bucket: %w", err)
} else if bucket == nil {
return fmt.Errorf("bucket not found")
}
// Create the package
newPackage, err := cdsCli.Package().Create(cdsapi.PackageCreate{
UserID: userID,
BucketID: bucket.ID,
Name: name,
})
if err != nil {
return fmt.Errorf("failed to create package: %w", err)
}
pkg := uploadersdk.Package{
UserID: userID,
PackageID: newPackage.Package.PackageID,
PackageName: name,
BucketID: bucket.ID,
DataType: dataType,
UploadPriority: uploadPriority,
}
// Pre-schedule the package and record the chosen clusters in the database
err = svc.packageScheduler(pkg.PackageID, uploadPriority)
if err != nil {
return err
}
// Persist the package record in the database
err = svc.JobSetSvc().db.UploadData().InsertPackage(svc.db.DefCtx(), pkg)
if err != nil {
return err
}
return nil
}
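// packageScheduler determines which clusters (and their storages) a package
// will be uploaded to, either by pre-scheduling against the given resource
// preferences or by using the explicitly specified clusters, and records the
// chosen clusters in the database.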
func (svc *JobSetService) packageScheduler(packageID cdssdk.PackageID, uploadPriority sch.UploadPriority) error {
clusterMapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
if err != nil {
return fmt.Errorf("query cluster mapping error: %w", err)
}
var clusters []uploadersdk.Cluster
switch uploadPriority := uploadPriority.(type) {
case *sch.Preferences:
// Perform pre-scheduling
clusterID, err := svc.preScheduler.ScheduleJob(uploadPriority.ResourcePriorities, clusterMapping)
if err != nil {
return fmt.Errorf("pre scheduling: %w", err)
}
storageID, ok := clusterMapping[*clusterID]
if !ok {
return fmt.Errorf("cluster %d not found", clusterID)
}
cluster := uploadersdk.Cluster{
PackageID: packageID,
ClusterID: *clusterID,
StorageID: storageID,
}
clusters = append(clusters, cluster)
case *sch.SpecifyCluster:
// Use the clusters specified by the user
for _, clusterID := range uploadPriority.Clusters {
storageID, ok := clusterMapping[clusterID]
if !ok {
logger.Warnf("cluster %d not found", clusterID)
continue
}
cluster := uploadersdk.Cluster{
PackageID: packageID,
ClusterID: clusterID,
StorageID: storageID,
}
clusters = append(clusters, cluster)
}
}
if len(clusters) == 0 {
return errors.New("no storage is available")
}
for _, clst := range clusters {
err := svc.db.UploadData().InsertUploadedCluster(svc.db.DefCtx(), clst)
if err != nil {
return err
}
}
return nil
}
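// DeletePackage deletes the package from storage and removes its record from
// the database.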
func (svc *JobSetService) DeletePackage(userID cdssdk.UserID, packageID cdssdk.PackageID) error {
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return fmt.Errorf("new cds client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
err = cdsCli.Package().Delete(cdsapi.PackageDelete{
UserID: userID,
PackageID: packageID,
})
if err != nil {
return fmt.Errorf("delete package: %w", err)
}
err = svc.JobSetSvc().db.UploadData().DeletePackage(svc.db.DefCtx(), userID, packageID)
if err != nil {
return err
}
return nil
}
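// QueryResource returns the clusters whose available resources satisfy the
// requested resource range.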
func (svc *JobSetService) QueryResource(queryResource sch.ResourceRange) ([]sch.ClusterDetail, error) {
clusterDetails, err := svc.getClusterResources()
if err != nil {
return nil, err
}
var results []sch.ClusterDetail
for _, cluster := range clusterDetails {
if len(cluster.Resources) == 0 {
continue
}
ok := isAppropriateResources(cluster.Resources, queryResource)
if ok {
results = append(results, cluster)
}
}
return results, nil
}
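// isAppropriateResources reports whether the cluster offers a resource of the
// queried type whose base CPU, memory, and storage availability all fall
// within the requested ranges.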
func isAppropriateResources(resources []sch.ClusterResource, queryResource sch.ResourceRange) bool {
for _, resource := range resources {
if resource.Resource.Type == queryResource.Type {
//ok := compareResource(queryResource.GPU.Min, queryResource.GPU.Max, resource.Resource.Available.Value)
//if !ok {
// return false
//}
if len(resource.BaseResources) == 0 {
return false
}
ok := false
for _, baseResource := range resource.BaseResources {
if baseResource.Type == sch.ResourceTypeCPU {
ok = compareResource(queryResource.CPU.Min, queryResource.CPU.Max, baseResource.Available.Value)
if !ok {
return false
}
}
if baseResource.Type == sch.ResourceTypeMemory {
ok = compareResource(queryResource.Memory.Min, queryResource.Memory.Max, baseResource.Available.Value)
if !ok {
return false
}
}
if baseResource.Type == sch.ResourceTypeStorage {
ok = compareResource(queryResource.Storage.Min, queryResource.Storage.Max, baseResource.Available.Value)
if !ok {
return false
}
}
}
return true
}
}
return false
}
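// compareResource reports whether v lies within [min, max]. A range of [0, 0]
// means no constraint and always matches; an inverted range (min > max) never
// matches.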
func compareResource(min float64, max float64, v float64) bool {
if min > max {
return false
}
if min == 0 && max == 0 {
return true
}
if v >= min && v <= max {
return true
}
return false
}
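// ResourceRange aggregates, for each resource type found across all clusters,
// the minimum and maximum available CPU, memory, and storage, and counts how
// many resources of that type were seen.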
func (svc *JobSetService) ResourceRange() ([]sch.ResourceRange, error) {
clusterDetails, err := svc.getClusterResources()
if err != nil {
return nil, err
}
// Initialize an empty map to hold the Range data for each resource type
resourceMap := make(map[sch.ResourceType]sch.ResourceRange)
// Iterate over all ClusterDetails
for _, cluster := range clusterDetails {
var CPUValue float64
var MemValue float64
var StorageValue float64
for _, resource := range cluster.Resources {
for _, baseResource := range resource.BaseResources {
// Check the resource type; skip types that are not aggregated
switch baseResource.Type {
case sch.ResourceTypeCPU:
CPUValue = baseResource.Available.Value
case sch.ResourceTypeMemory:
MemValue = baseResource.Available.Value
case sch.ResourceTypeStorage:
StorageValue = baseResource.Available.Value
}
}
}
// Iterate over this cluster's resource list
for _, resource := range cluster.Resources {
// The resource type is the key into resourceMap
resourceType := resource.Resource.Type
// Get the resource's available value
//availableValue := resource.Available.Value
// Fetch the existing ResourceRange for this type
resourceRange, exists := resourceMap[resourceType]
if !exists {
// First occurrence of this resource type: initialize a new Range
resourceRange = sch.ResourceRange{
Type: resourceType,
GPU: sch.Range{}, // initial GPU range
GPUNumber: 0, // initial GPU count
CPU: sch.Range{}, // initial CPU range
Memory: sch.Range{}, // initial Memory range
Storage: sch.Range{}, // initial Storage range
}
}
if CPUValue < resourceRange.CPU.Min || resourceRange.CPU.Min == 0 {
resourceRange.CPU.Min = CPUValue
}
if CPUValue > resourceRange.CPU.Max {
resourceRange.CPU.Max = CPUValue
}
if MemValue < resourceRange.Memory.Min || resourceRange.Memory.Min == 0 {
resourceRange.Memory.Min = MemValue
}
if MemValue > resourceRange.Memory.Max {
resourceRange.Memory.Max = MemValue
}
if StorageValue < resourceRange.Storage.Min || resourceRange.Storage.Min == 0 {
resourceRange.Storage.Min = StorageValue
}
if StorageValue > resourceRange.Storage.Max {
resourceRange.Storage.Max = StorageValue
}
// Increment the resource count
resourceRange.GPUNumber++
// Write the updated ResourceRange back into resourceMap
resourceMap[resourceType] = resourceRange
}
}
// Convert the map into a slice
var result []sch.ResourceRange
for _, rangeData := range resourceMap {
result = append(result, rangeData)
}
// Return the aggregated statistics
return result, nil
}
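// getClusterResources queries the PCM scheduler for the details of every
// cluster present in the cluster/storage mapping.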
func (svc *JobSetService) getClusterResources() ([]sch.ClusterDetail, error) {
schCli, err := schglb.PCMSchePool.Acquire()
if err != nil {
return nil, fmt.Errorf("new scheduler client: %w", err)
}
defer schglb.PCMSchePool.Release(schCli)
clusterMapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
if err != nil {
return nil, fmt.Errorf("query cluster mapping: %w", err)
}
// Query the specified computing centers
clusterIDs := make([]schsdk.ClusterID, 0, len(clusterMapping))
for id := range clusterMapping {
clusterIDs = append(clusterIDs, id)
}
clusterDetails, err := schCli.GetClusterInfo(sch.GetClusterInfoReq{
IDs: clusterIDs,
})
if err != nil {
return nil, fmt.Errorf("get cluster info: %w", err)
}
if len(clusterDetails) == 0 {
return nil, errors.New("no cluster found")
}
return clusterDetails, nil
}
func (svc *JobSetService) QueryImages(IDs []int64) ([]sch.ClusterImage, error) {
images, err := svc.db.UploadData().GetImageByID(svc.db.DefCtx(), IDs)
if err != nil {
return nil, fmt.Errorf("query images: %w", err)
}
return images, nil
}
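// UpdateCode creates a new version of an algorithm package by cloning it
// under a versioned name and recording the clone in the package version table.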
func (svc *JobSetService) UpdateCode(userID cdssdk.UserID, bucketID cdssdk.BucketID, packageID cdssdk.PackageID, packageName string) (*cdssdk.Package, error) {
// Clone the package
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return nil, fmt.Errorf("new cds client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
maxVersion, err := svc.db.UploadData().GetMaxVersion(svc.db.DefCtx(), packageID)
if err != nil {
return nil, fmt.Errorf("get max version: %w", err)
}
version := maxVersion + 1
packageName = fmt.Sprintf("%s_%d", packageName, version)
cloneReq := cdsapi.PackageClone{
PackageID: packageID,
Name: packageName,
BucketID: bucketID,
UserID: userID,
}
cloneResp, err := cdsCli.Package().Clone(cloneReq)
if err != nil {
return nil, fmt.Errorf("clone package: %w", err)
}
// Record the clone in the version table
err = svc.db.UploadData().InsertPackageVersion(svc.db.DefCtx(), packageID, cloneResp.Package.PackageID, cloneResp.Package.Name, version)
if err != nil {
return nil, fmt.Errorf("insert package version: %w", err)
}
// Return the cloned package
return &cloneResp.Package, nil
}