JCC-CSScheduler/schedulerMiddleware/internal/services/jobset.go

1446 lines
40 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
"encoding/json"
"errors"
"fmt"
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/algorithm"
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/dataset"
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/model"
jobTask "gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/task"
"sort"
"strconv"
"gitlink.org.cn/cloudream/common/pkgs/logger"
sch "gitlink.org.cn/cloudream/common/sdks/pcmscheduler"
"gitlink.org.cn/cloudream/common/sdks/storage/cdsapi"
uploadersdk "gitlink.org.cn/cloudream/common/sdks/uploader"
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/event"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job/state"
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job/state2"
"strings"
"time"
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
)
// JobSetService exposes job-set-level operations (submission, data upload,
// data binding, package management, resource queries) on top of the base
// Service. It shares the Service's db, jobMgr, preScheduler and hubClient.
type JobSetService struct {
	*Service
}

// JobSetSvc returns the job-set view of this Service.
func (svc *Service) JobSetSvc() *JobSetService {
	return &JobSetService{Service: svc}
}
// PreScheduler runs pre-scheduling for a job set against all known computing
// centers and returns the resulting schedule scheme plus the plan describing
// which files must be uploaded where.
func (svc *JobSetService) PreScheduler(jobSet schsdk.JobSetInfo) (*jobmod.JobSetPreScheduleScheme, *schsdk.JobSetFilesUploadScheme, error) {
	ccs, err := svc.db.ComputingCenter().GetAll(svc.db.DefCtx())
	if err != nil {
		logger.Warnf("getting all computing center: %s", err.Error())
		// Wrap with context, consistent with the pre-scheduling branch below
		// (the original returned the raw error here).
		return nil, nil, fmt.Errorf("getting all computing centers: %w", err)
	}
	schScheme, uploadScheme, err := svc.preScheduler.ScheduleJobSet(&jobSet, ccs)
	if err != nil {
		return nil, nil, fmt.Errorf("pre scheduling: %w", err)
	}
	return schScheme, uploadScheme, nil
}
// Upload starts an asynchronous data-upload job for an already-created package
// and returns the ID of the single-job job set that tracks it.
// packageID must refer to a package previously registered via CreatePackage;
// blockChainToken and task are threaded into the upload state unchanged.
func (svc *JobSetService) Upload(userID cdssdk.UserID, packageID cdssdk.PackageID, params sch.UploadParams, blockChainToken string, task *jobTask.JobTask[sch.TaskMessage]) (*schsdk.JobSetID, error) {
	logger.Debugf("uploading job")
	// Resolve the bucket that owns this package via its packageID.
	pkg, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), packageID)
	if err != nil {
		logger.Warnf("getting upload data: %s", err.Error())
		return nil, err
	}
	// A zero PackageID means the row was not found.
	if pkg.PackageID == 0 {
		return nil, errors.New("packageID is not found")
	}
	// NOTE(review): hubClient appears to be shared on the service; mutating
	// its UserId/BucketID here is not safe if Upload can run concurrently —
	// confirm single-threaded use or move these into per-call state.
	svc.hubClient.UserId = int(userID)
	svc.hubClient.BucketID = int(pkg.BucketID)
	upload := state2.DataUpload{
		UserID:          userID,
		PackageID:       packageID,
		BlockChainToken: blockChainToken,
		Task:            task,
		HubClient:       svc.hubClient,
		DataType:        params.DataType,
		UploadInfo:      params.UploadInfo,
	}
	// Wrap the upload into a one-job job set whose initial state drives it.
	var jobs []jobmgr.SubmittingJob
	jo := job.NewNormalJob(schsdk.NormalJobInfo{})
	jobs = append(jobs, jobmgr.SubmittingJob{
		Body:      jo,
		InitState: state2.NewDataUpload(&upload),
	})
	jobSetID := svc.jobMgr.SubmitJobSet(jobs)
	return &jobSetID, nil
}
// Submit 提交任务集
// Submit submits a job set. Each job description in jobSet is converted into a
// SubmittingJob with the initial state appropriate for its type, and the whole
// batch is handed to the job manager. schScheme must contain a pre-schedule
// scheme for every multi-instance / data-preprocess / finetuning job.
func (svc *JobSetService) Submit(userID cdssdk.UserID, jobSet schsdk.JobSetInfo, schScheme *jobmod.JobSetPreScheduleScheme, token string) (*schsdk.JobSetID, error) {
	logger.Debugf("submitting job")
	var jobs []jobmgr.SubmittingJob
	for _, jobInfo := range jobSet.Jobs {
		switch info := jobInfo.(type) {
		case *schsdk.PCMJobInfo:
			jo := job.NewPCMJob(*info)
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body:      jo,
				InitState: state2.NewPCMJobCreate(userID, info, token, svc.hubClient),
			})
		case *schsdk.NormalJobInfo:
			jo := job.NewNormalJob(*info)
			jo.SubType = schsdk.JobTypeNormal
			// NOTE(review): no InitState is set for plain normal jobs — the
			// pre-scheduling state was commented out upstream; confirm the job
			// manager supplies a default state for a nil InitState.
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body: jo,
			})
		case *schsdk.DataReturnJobInfo:
			jo := job.NewDataReturnJob(*info)
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body:      jo,
				InitState: state.NewWaitTargetComplete(),
			})
		case *schsdk.MultiInstanceJobInfo:
			// BUGFIX: verify the scheme exists before using it; the original
			// constructed the job with a zero-value scheme first and checked
			// ok afterwards. Also errors.New(fmt.Sprintf(...)) -> fmt.Errorf.
			preSch, ok := schScheme.JobSchemes[info.LocalJobID]
			if !ok {
				return nil, fmt.Errorf("pre schedule scheme for job %s is not found", info.LocalJobID)
			}
			jo := job.NewMultiInstanceJob(*info, preSch)
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body:      jo,
				InitState: state.NewMultiInstanceInit(),
			})
		case *schsdk.UpdateMultiInstanceJobInfo:
			modelJob := job.NewUpdateMultiInstanceJob(*info)
			instanceJobSets := svc.jobMgr.DumpJobSet(modelJob.Info.MultiInstanceJobSetID)
			if len(instanceJobSets) == 0 {
				return nil, fmt.Errorf("job set %s is not found", modelJob.Info.MultiInstanceJobSetID)
			}
			// Locate the multi-instance job itself inside the dumped set.
			var multiInstanceJobDump jobmod.JobDump
			for i := 0; i < len(instanceJobSets); i++ {
				jobDump := instanceJobSets[i]
				if _, ok := jobDump.Body.(*jobmod.MultiInstanceJobDump); ok {
					multiInstanceJobDump = jobDump
					break
				}
			}
			// NOTE(review): if no MultiInstanceJobDump is present the zero
			// value is passed on — confirm downstream handles that.
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body:      modelJob,
				InitState: state.NewMultiInstanceUpdate(multiInstanceJobDump),
			})
		case *schsdk.DataPreprocessJobInfo:
			// Scheduling then proceeds exactly like a normal job.
			normalJobInfo := &schsdk.NormalJobInfo{
				Type:        schsdk.JobTypeNormal,
				JobInfoBase: info.JobInfoBase,
				Files:       info.Files,
				Runtime:     info.Runtime,
				Services:    info.Services,
				Resources:   info.Resources,
			}
			jo := job.NewNormalJob(*normalJobInfo)
			jo.SubType = schsdk.JobTypeDataPreprocess
			preSch, ok := schScheme.JobSchemes[info.LocalJobID]
			if !ok {
				return nil, fmt.Errorf("pre schedule scheme for job %s is not found", info.LocalJobID)
			}
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body:      jo,
				InitState: state.NewPreSchuduling(preSch),
			})
		case *schsdk.FinetuningJobInfo:
			// Scheduling then proceeds exactly like a normal job.
			normalJobInfo := &schsdk.NormalJobInfo{
				Type:         schsdk.JobTypeNormal,
				Files:        info.Files,
				JobInfoBase:  info.JobInfoBase,
				Runtime:      info.Runtime,
				Services:     info.Services,
				Resources:    info.Resources,
				ModelJobInfo: info.ModelJobInfo,
			}
			jo := job.NewNormalJob(*normalJobInfo)
			jo.SubType = schsdk.JobTypeFinetuning
			preSch, ok := schScheme.JobSchemes[info.LocalJobID]
			if !ok {
				return nil, fmt.Errorf("pre schedule scheme for job %s is not found", info.LocalJobID)
			}
			jobs = append(jobs, jobmgr.SubmittingJob{
				Body:      jo,
				InitState: state.NewPreSchuduling(preSch),
			})
		}
	}
	jobSetID := svc.jobMgr.SubmitJobSet(jobs)
	return &jobSetID, nil
}
// LocalFileUploaded 任务集中某个文件上传完成
func (svc *JobSetService) LocalFileUploaded(jobSetID schsdk.JobSetID, localPath string, errMsg string, packageID cdssdk.PackageID, objectIDs []cdssdk.ObjectID) {
err := errors.New(errMsg)
svc.jobMgr.BroadcastEvent(jobSetID, event.NewLocalFileUploaded(localPath, err, packageID, objectIDs))
}
// CreateFolder records a (possibly empty) folder entry for the package in the
// upload-data table. The folder exists only as a database row until objects
// are uploaded under it.
func (svc *JobSetService) CreateFolder(packageID cdssdk.PackageID, path string) error {
	// svc is already a *JobSetService, so the original svc.JobSetSvc()
	// indirection was redundant; the if err != nil / return nil tail
	// collapses to a direct return.
	return svc.db.UploadData().InsertFolder(svc.db.DefCtx(), packageID, path)
}
// 删除文件或文件夹
func (svc *JobSetService) DeleteFile(userID cdssdk.UserID, objectIDs []cdssdk.ObjectID) error {
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
if err != nil {
return fmt.Errorf("new scheduler client: %w", err)
}
defer schglb.CloudreamStoragePool.Release(cdsCli)
err = cdsCli.Object().Delete(cdsapi.ObjectDelete{
ObjectIDs: objectIDs,
UserID: userID,
})
if err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
return nil
}
// DeleteFolder removes every object whose path lies under path inside the
// package from cloud storage, then deletes the folder record from the local
// database.
func (svc *JobSetService) DeleteFolder(userID cdssdk.UserID, packageID cdssdk.PackageID, path string) error {
	cdsCli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return fmt.Errorf("new scheduler client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(cdsCli)
	// Enumerate everything under the prefix.
	list, err := cdsCli.Object().ListByPath(cdsapi.ObjectListByPath{
		UserID:    userID,
		PackageID: packageID,
		Path:      path,
		IsPrefix:  true,
	})
	if err != nil {
		// BUGFIX: the original wrapped this as "failed to delete object",
		// hiding the fact that the listing step is what failed.
		return fmt.Errorf("failed to list objects: %w", err)
	}
	if len(list.Objects) > 0 {
		objectIDs := make([]cdssdk.ObjectID, 0, len(list.Objects))
		for _, obj := range list.Objects {
			objectIDs = append(objectIDs, obj.ObjectID)
		}
		err = cdsCli.Object().Delete(cdsapi.ObjectDelete{
			ObjectIDs: objectIDs,
			UserID:    userID,
		})
		if err != nil {
			return fmt.Errorf("failed to delete object: %w", err)
		}
	}
	// Remove the folder row itself.
	if err := svc.db.UploadData().DeleteFolder(svc.db.DefCtx(), packageID, path); err != nil {
		// BUGFIX: distinct message for the database step (was "failed to
		// delete object" as well).
		return fmt.Errorf("failed to delete folder record: %w", err)
	}
	return nil
}
// QueryUploaded lists the contents of an uploaded package. With
// PackageID == -1 it lists the packages themselves (root view); otherwise it
// lists objects under queryParams.Path, synthesizing folder entries
// (ObjectID -1) for deeper levels and database-only folders, sorted and
// paginated per queryParams. Returns (packages, totalPages, totalNum, error).
func (svc *JobSetService) QueryUploaded(queryParams sch.QueryData) ([]uploadersdk.Package, int, int, error) {
	// Root view: return the packages themselves.
	if queryParams.PackageID == -1 {
		packages, err := svc.db.UploadData().QueryPackage(svc.db.DefCtx(), queryParams)
		if err != nil {
			return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
		}
		return packages, 0, 0, nil
	}
	cdsCli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return nil, 0, 0, fmt.Errorf("new scheduler client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(cdsCli)
	queryListReq := cdsapi.ObjectListByPath{
		UserID:    queryParams.UserID,
		PackageID: queryParams.PackageID,
		Path:      queryParams.Path,
		IsPrefix:  true,
	}
	objList, err := cdsCli.Object().ListByPath(queryListReq)
	if err != nil {
		return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
	}
	// Collapse entries more than one level deep into synthetic folder
	// objects identified by ObjectID -1.
	folderMap := make(map[string]cdssdk.Object)
	var modifyObjs []cdssdk.Object
	for _, obj := range objList.Objects {
		// Strip the query path prefix so paths are relative to the listing root.
		obj.Path = strings.TrimPrefix(obj.Path, queryParams.Path)
		pathArr := strings.Split(obj.Path, "/")
		if len(pathArr) > 2 {
			splitPath := "/" + pathArr[1]
			folderMap[splitPath] = cdssdk.Object{
				ObjectID:   -1,
				PackageID:  obj.PackageID,
				Path:       pathArr[1],
				Size:       0,
				CreateTime: obj.CreateTime,
			}
			continue
		}
		modifyObjs = append(modifyObjs, obj)
	}
	// Merge explicitly created (possibly empty) folders from the database.
	folders, err := svc.db.UploadData().QueryFolder(svc.db.DefCtx(), queryParams)
	if err != nil {
		return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
	}
	for _, folder := range folders {
		folder.Path = strings.TrimPrefix(folder.Path, queryParams.Path)
		if folder.Path == "" {
			continue
		}
		folderMap[folder.Path] = cdssdk.Object{
			ObjectID:   -1,
			PackageID:  folder.PackageID,
			Path:       folder.Path,
			Size:       0,
			CreateTime: folder.CreateTime,
		}
	}
	for _, obj := range folderMap {
		modifyObjs = append(modifyObjs, obj)
	}
	objList.Objects = modifyObjs
	// Sort by the requested field.
	sort.Slice(objList.Objects, func(i, j int) bool {
		switch queryParams.OrderBy {
		case sch.OrderByName:
			return objList.Objects[i].Path < objList.Objects[j].Path
		case sch.OrderBySize:
			return objList.Objects[i].Size < objList.Objects[j].Size
		case sch.OrderByTime:
			return objList.Objects[i].CreateTime.Unix() < objList.Objects[j].CreateTime.Unix()
		}
		return false
	})
	totalNum := len(objList.Objects)
	// Paginate. BUGFIX: the original computed totalNum/PageSize after this
	// block unconditionally, panicking with a division by zero when
	// PageSize == 0; it also allowed a negative slice start for
	// CurrentPage == 0 and floored the page count, dropping a partial page.
	totalPages := 0
	if queryParams.PageSize > 0 {
		start := (queryParams.CurrentPage - 1) * queryParams.PageSize
		if start < 0 {
			start = 0
		}
		end := start + queryParams.PageSize
		if start > totalNum {
			return nil, 0, 0, nil
		}
		if end > totalNum {
			end = totalNum
		}
		objList.Objects = objList.Objects[start:end]
		// Ceiling division so a final partial page is counted.
		totalPages = (totalNum + queryParams.PageSize - 1) / queryParams.PageSize
	}
	data, err := svc.db.UploadData().QueryPackageByID(svc.db.DefCtx(), queryParams.PackageID)
	if err != nil {
		return nil, 0, 0, err
	}
	pkg := uploadersdk.Package{
		PackageID:       data.PackageID,
		BucketID:        data.BucketID,
		DataType:        data.DataType,
		PackageName:     data.PackageName,
		JsonData:        data.JsonData,
		BindingID:       data.BindingID,
		UserID:          data.UserID,
		CreateTime:      data.CreateTime,
		Objects:         objList.Objects,
		UploadedCluster: data.UploadedCluster,
	}
	return []uploadersdk.Package{pkg}, totalPages, totalNum, nil
}
// DataBinding binds an uploaded package to a set of clusters as a dataset,
// algorithm (code) or model: it verifies the request, transfers the package
// to any target cluster that does not hold it yet, registers it on each
// cluster's AI platform (ModelArts / OpenI), and persists the binding plus
// per-cluster results. Image bindings are not supported yet.
func (svc *JobSetService) DataBinding(id uploadersdk.DataID, userID cdssdk.UserID, info sch.DataBinding) error {
	var bindingData uploadersdk.Binding
	var bindingClusters []uploadersdk.BindingCluster
	var packageID cdssdk.PackageID
	switch bd := info.(type) {
	case *sch.DatasetBinding:
		pkg, clusterMap, err := svc.queryAndVerify(bd.PackageID, bd.ClusterIDs, userID, bd.Name, sch.DATASET)
		if err != nil {
			return err
		}
		param := dataScheduleParam{
			BindingClusterIDs: bd.ClusterIDs,
			ClusterMap:        clusterMap,
			PackageID:         bd.PackageID,
			UploadedClusters:  pkg.UploadedCluster,
			UserID:            userID,
		}
		// If a target cluster does not hold the data yet, transfer it first.
		err = svc.dataSchedule(param)
		if err != nil {
			return err
		}
		for _, clusterID := range bd.ClusterIDs {
			clusterInfo, ok := clusterMap[clusterID]
			if !ok {
				return fmt.Errorf("cluster %s not found", clusterID)
			}
			filePath := clusterInfo.StoragePath + sch.Split + sch.DATASET + sch.Split + pkg.PackageName
			datasetParam := &dataset.Dataset{
				Name:        bd.Name,
				Description: bd.Description,
				FilePath:    filePath,
				Category:    dataset.CommonValue(bd.Category),
			}
			var jsonData []byte
			switch clusterInfo.ClusterName {
			case sch.PlatformModelArts:
				resp, err := svc.hubClient.ModelArts.BindDataset(datasetParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			case sch.PlatformOpenI:
				datasetParam.RepoName = ""
				resp, err := svc.hubClient.OpenI.BindDataset(datasetParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			}
			bindingClusters = append(bindingClusters, uploadersdk.BindingCluster{
				ClusterID: uploadersdk.ClusterID(clusterID),
				Status:    sch.SuccessStatus,
				JsonData:  string(jsonData),
			})
		}
		content, err := json.Marshal(bd)
		if err != nil {
			return err
		}
		bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content))
		packageID = bd.PackageID
	case *sch.CodeBinding:
		pkg, clusterMap, err := svc.queryAndVerify(bd.PackageID, bd.ClusterIDs, userID, bd.Name, sch.CODE)
		if err != nil {
			return err
		}
		param := dataScheduleParam{
			BindingClusterIDs: bd.ClusterIDs,
			ClusterMap:        clusterMap,
			PackageID:         bd.PackageID,
			UploadedClusters:  pkg.UploadedCluster,
			UserID:            userID,
		}
		// If a target cluster does not hold the data yet, transfer it first.
		err = svc.dataSchedule(param)
		if err != nil {
			return err
		}
		for _, clusterID := range bd.ClusterIDs {
			clusterInfo, ok := clusterMap[clusterID]
			if !ok {
				return fmt.Errorf("cluster %s not found", clusterID)
			}
			filePath := clusterInfo.StoragePath + sch.Split + sch.CODE + sch.Split + pkg.PackageName
			codeParam := &algorithm.Algorithm{
				Name:        bd.Name,
				Description: bd.Description,
				CodeDir:     filePath,
				BootFile:    "",
			}
			var jsonData []byte
			switch clusterInfo.ClusterName {
			case sch.PlatformModelArts:
				// BUGFIX: the original built an unused algorithm.Engine here
				// and passed it to the builtin println (which does not accept
				// struct arguments); the dead debug code has been removed.
				resp, err := svc.hubClient.ModelArts.BindAlgorithm(codeParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			case sch.PlatformOpenI:
				// OpenI does not take an Engine; the code is uploaded separately.
				resp, err := svc.hubClient.OpenI.BindAlgorithm(codeParam)
				if err != nil {
					return err
				}
				err = svc.hubClient.UploadOpenIAlgorithm(codeParam, int(bd.PackageID))
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			}
			bindingClusters = append(bindingClusters, uploadersdk.BindingCluster{
				ClusterID: uploadersdk.ClusterID(clusterID),
				Status:    sch.SuccessStatus,
				JsonData:  string(jsonData),
			})
		}
		content, err := json.Marshal(bd)
		if err != nil {
			return err
		}
		bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content))
		packageID = bd.PackageID
	case *sch.ImageBinding:
		return fmt.Errorf("not support image binding")
	case *sch.ModelBinding:
		pkg, clusterMap, err := svc.queryAndVerify(bd.PackageID, bd.ClusterIDs, userID, bd.Name, sch.MODEL)
		if err != nil {
			return err
		}
		param := dataScheduleParam{
			BindingClusterIDs: bd.ClusterIDs,
			ClusterMap:        clusterMap,
			PackageID:         bd.PackageID,
			UploadedClusters:  pkg.UploadedCluster,
			UserID:            userID,
		}
		// If a target cluster does not hold the data yet, transfer it first.
		err = svc.dataSchedule(param)
		if err != nil {
			return err
		}
		for _, clusterID := range bd.ClusterIDs {
			clusterInfo, ok := clusterMap[clusterID]
			if !ok {
				return fmt.Errorf("cluster %s not found", clusterID)
			}
			filePath := clusterInfo.StoragePath + sch.Split + sch.MODEL + sch.Split + pkg.PackageName
			modelParam := &model.Model{
				Name:        bd.Name,
				Description: bd.Description,
				Type:        bd.ModelType,
				FilePath:    filePath,
			}
			var jsonData []byte
			switch clusterInfo.ClusterName {
			case sch.PlatformModelArts:
				resp, err := svc.hubClient.ModelArts.BindModel(modelParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			case sch.PlatformOpenI:
				modelParam.RepoName = ""
				resp, err := svc.hubClient.OpenI.BindModel(modelParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			}
			bindingClusters = append(bindingClusters, uploadersdk.BindingCluster{
				ClusterID: uploadersdk.ClusterID(clusterID),
				Status:    sch.SuccessStatus,
				JsonData:  string(jsonData),
			})
		}
		content, err := json.Marshal(bd)
		if err != nil {
			return err
		}
		bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content))
		packageID = bd.PackageID
	}
	// Default to private visibility.
	if bindingData.AccessLevel == "" {
		bindingData.AccessLevel = sch.PrivateAccess
	}
	bindingData.CreateTime = time.Now()
	_, err := svc.db.UploadData().InsertOrUpdateBinding(svc.db.DefCtx(), bindingData, bindingClusters, packageID)
	return err
}
// dataScheduleParam bundles the arguments for dataSchedule.
type dataScheduleParam struct {
	UserID            cdssdk.UserID                                   // owner of the package being scheduled
	PackageID         cdssdk.PackageID                                // package to transfer
	BindingClusterIDs []schsdk.ClusterID                              // clusters the binding targets
	UploadedClusters  []uploadersdk.Cluster                           // clusters that already hold the data
	ClusterMap        map[schsdk.ClusterID]uploadersdk.ClusterMapping // cluster ID -> storage mapping
}
// dataSchedule ships the package to every binding cluster that does not
// already hold a copy, recording each successful transfer in the
// uploaded-cluster table.
func (svc *JobSetService) dataSchedule(param dataScheduleParam) error {
	// Set of clusters that already have the data.
	uploaded := make(map[schsdk.ClusterID]struct{}, len(param.UploadedClusters))
	for _, c := range param.UploadedClusters {
		uploaded[c.ClusterID] = struct{}{}
	}
	// Binding clusters that still need the data (and have a known mapping).
	var pending []uploadersdk.ClusterMapping
	for _, cid := range param.BindingClusterIDs {
		if _, has := uploaded[cid]; has {
			continue
		}
		if mapping, ok := param.ClusterMap[cid]; ok {
			pending = append(pending, mapping)
		}
	}
	// Transfer the package to each pending cluster and persist the result.
	for _, target := range pending {
		_, err := svc.hubClient.LoadPackage(uint(param.PackageID), uint(param.UserID), uint(target.StorageID), "")
		if err != nil {
			logger.Error("data schedule failed, error: ", err.Error())
			return err
		}
		record := uploadersdk.Cluster{
			PackageID: param.PackageID,
			ClusterID: target.ClusterID,
			StorageID: target.StorageID,
		}
		if err := svc.db.UploadData().InsertUploadedCluster(svc.db.DefCtx(), record); err != nil {
			logger.Error("insert uploadedCluster failed, error: ", err.Error())
			return err
		}
	}
	return nil
}
// queryAndVerify validates a binding request: the binding name must be unused
// for this user and data type, the package must exist and not already be
// bound, and the requested cluster IDs are resolved to their storage mappings.
// Returns the package row and a clusterID -> mapping lookup table.
// (Parameter was originally misspelled "pacakgeID"; Go arguments are
// positional, so the rename is interface-safe.)
func (svc *JobSetService) queryAndVerify(packageID cdssdk.PackageID, clusterIDs []schsdk.ClusterID, userID cdssdk.UserID, name string, dataType string) (*uploadersdk.PackageDAO, map[schsdk.ClusterID]uploadersdk.ClusterMapping, error) {
	// A binding name must be unique per user and data type.
	existBinding, err := svc.db.UploadData().GetBindingByName(svc.db.DefCtx(), userID, name, dataType)
	if err != nil {
		return nil, nil, err
	}
	if existBinding.ID != 0 {
		return nil, nil, fmt.Errorf("name %s already exists", name)
	}
	// The package must exist (zero PackageID means "no row").
	pkg, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), packageID)
	if err != nil {
		return nil, nil, err
	}
	if pkg.PackageID == 0 {
		return nil, nil, fmt.Errorf("no package found")
	}
	// A package may only be bound once.
	if pkg.BindingID != -1 {
		binding, err := svc.db.UploadData().GetBindingsByID(svc.db.DefCtx(), pkg.BindingID)
		if err != nil {
			return nil, nil, err
		}
		// BUGFIX: the name used to be concatenated into the format string, so
		// a '%' in the name would corrupt the message; pass it as an argument.
		return nil, nil, fmt.Errorf("binding already exists, name: %s", binding.Name)
	}
	clusterMap := make(map[schsdk.ClusterID]uploadersdk.ClusterMapping)
	clusters, err := svc.db.UploadData().GetClusterByID(svc.db.DefCtx(), clusterIDs)
	if err != nil {
		return nil, nil, err
	}
	for _, cluster := range clusters {
		clusterMap[cluster.ClusterID] = cluster
	}
	return &pkg, clusterMap, nil
}
// getBindingData assembles a Binding record from its constituent fields.
func getBindingData(id uploadersdk.DataID, userID cdssdk.UserID, dataType string, name string, content string) uploadersdk.Binding {
	return uploadersdk.Binding{
		ID:       id,
		Name:     name,
		DataType: dataType,
		UserID:   userID,
		Content:  content,
	}
}
// RemoveBinding detaches bindings: packages listed in packageIDs get their
// BindingID reset to -1 (unbound), and bindings listed in bindingIDs are
// deleted outright. (Parameter was originally misspelled "pacakgeIDs"; Go
// arguments are positional, so the rename is interface-safe.)
func (svc *JobSetService) RemoveBinding(packageIDs []cdssdk.PackageID, bindingIDs []int64) error {
	// Ranging over an empty slice is a no-op, so the original len()>0 guard
	// around this loop was redundant.
	for _, id := range packageIDs {
		pkgDao := uploadersdk.PackageDAO{
			BindingID: -1,
		}
		if err := svc.db.UploadData().UpdatePackage(svc.db.DefCtx(), id, pkgDao); err != nil {
			return err
		}
	}
	if len(bindingIDs) > 0 {
		if err := svc.db.UploadData().DeleteBindingsByID(svc.db.DefCtx(), bindingIDs); err != nil {
			return err
		}
	}
	return nil
}
// DeleteBinding permanently deletes the bindings with the given IDs.
func (svc *JobSetService) DeleteBinding(IDs []int64) error {
	return svc.db.UploadData().DeleteBindingsByID(svc.db.DefCtx(), IDs)
}
// QueryBinding lists binding details for the given data type, scoped by the
// access-level parameter: private (the caller's own bindings), public, or
// apply (bindings the caller has requested access to).
// NOTE(review): filters are honored only in the ApplyLevel branch — confirm
// that Private/Public intentionally ignore them.
func (svc *JobSetService) QueryBinding(dataType string, param sch.QueryBindingDataParam, filters sch.QueryBindingFilters) ([]uploadersdk.BindingDetail, error) {
	switch p := param.(type) {
	case *sch.PrivateLevel:
		return svc.queryPrivateBinding(p.UserID, uploadersdk.DataID(p.BindingID), dataType)
	case *sch.PublicLevel:
		var details []uploadersdk.BindingDetail
		datas, err := svc.db.UploadData().GetPublicBindings(svc.db.DefCtx(), p.Type, dataType, p.UserID)
		if err != nil {
			return nil, err
		}
		for _, data := range datas {
			var info sch.DataBinding
			binding := uploadersdk.Binding{
				DataType: dataType,
				Content:  data.Content,
			}
			info, err = unmarshalBinding(binding)
			if err != nil {
				return nil, err
			}
			bindingDetail := uploadersdk.BindingDetail{
				ID:          data.ID,
				UserID:      data.UserID,
				UserName:    data.UserName,
				SSOId:       data.SSOId,
				Name:        data.Name,
				Info:        info,
				AccessLevel: data.AccessLevel,
				CreateTime:  data.CreateTime,
			}
			details = append(details, bindingDetail)
		}
		return details, nil
	case *sch.ApplyLevel:
		var details []uploadersdk.BindingDetail
		datas, err := svc.db.UploadData().GetApplyBindings(svc.db.DefCtx(), p.UserID, p.Type, dataType)
		if err != nil {
			return nil, err
		}
		for _, data := range datas {
			var info sch.DataBinding
			// If a status filter is given, only rows with that status are kept.
			if filters.Status != "" && data.Status != filters.Status {
				continue
			}
			if filters.Name != "" && data.Name != filters.Name {
				continue
			}
			// Only approved applications expose the full binding details;
			// otherwise Info stays nil.
			if data.Status == sch.ApprovedStatus {
				binding := uploadersdk.Binding{
					DataType: dataType,
					Content:  data.Content,
				}
				info, err = unmarshalBinding(binding)
				if err != nil {
					return nil, err
				}
			}
			bindingDetail := uploadersdk.BindingDetail{
				ID:          data.ID,
				UserID:      data.UserID,
				UserName:    data.UserName,
				SSOId:       data.SSOId,
				Name:        data.Name,
				Info:        info,
				Status:      data.Status,
				AccessLevel: data.AccessLevel,
				CreateTime:  data.CreateTime,
			}
			details = append(details, bindingDetail)
		}
		return details, nil
	}
	return nil, fmt.Errorf("unknown query binding data type")
}
// queryPrivateBinding returns the caller's own bindings of the given data
// type. bindingID == -1 is a sentinel meaning "list all"; otherwise only that
// single binding, with its bound packages attached, is returned.
func (svc *JobSetService) queryPrivateBinding(userID cdssdk.UserID, bindingID uploadersdk.DataID, dataType string) ([]uploadersdk.BindingDetail, error) {
	var details []uploadersdk.BindingDetail
	if bindingID == -1 {
		// List mode: every private binding of this user and type.
		datas, err := svc.db.UploadData().GetPrivateBindings(svc.db.DefCtx(), userID, dataType)
		if err != nil {
			return nil, err
		}
		for _, data := range datas {
			info, err := unmarshalBinding(data)
			if err != nil {
				return nil, err
			}
			bindingDetail := uploadersdk.BindingDetail{
				ID:          data.ID,
				UserID:      data.UserID,
				Info:        info,
				AccessLevel: data.AccessLevel,
				CreateTime:  data.CreateTime,
			}
			details = append(details, bindingDetail)
		}
		return details, nil
	}
	// Detail mode: a single binding plus the packages bound to it.
	data, err := svc.db.UploadData().GetBindingsByID(svc.db.DefCtx(), bindingID)
	if err != nil {
		return nil, err
	}
	info, err := unmarshalBinding(*data)
	if err != nil {
		return nil, err
	}
	packages, err := svc.db.UploadData().QueryPackageByBindingID(svc.db.DefCtx(), bindingID)
	if err != nil {
		return nil, err
	}
	detail := uploadersdk.BindingDetail{
		ID:          bindingID,
		UserID:      data.UserID,
		Packages:    packages,
		Info:        info,
		AccessLevel: data.AccessLevel,
		CreateTime:  data.CreateTime,
	}
	details = append(details, detail)
	return details, nil
}
// unmarshalBinding decodes a Binding's JSON Content into the concrete
// sch.DataBinding implementation selected by its DataType. An unrecognized
// DataType yields (nil, nil), matching the stored record being opaque.
func unmarshalBinding(data uploadersdk.Binding) (sch.DataBinding, error) {
	raw := []byte(data.Content)
	switch data.DataType {
	case sch.DATASET:
		content := &sch.DatasetBinding{}
		if err := json.Unmarshal(raw, content); err != nil {
			return nil, err
		}
		return content, nil
	case sch.CODE:
		content := &sch.CodeBinding{}
		if err := json.Unmarshal(raw, content); err != nil {
			return nil, err
		}
		return content, nil
	case sch.IMAGE:
		content := &sch.ImageBinding{}
		if err := json.Unmarshal(raw, content); err != nil {
			return nil, err
		}
		return content, nil
	case sch.MODEL:
		content := &sch.ModelBinding{}
		if err := json.Unmarshal(raw, content); err != nil {
			return nil, err
		}
		return content, nil
	}
	return nil, nil
}
// CreatePackage creates a new storage package named name for userID in the
// user's bucket for dataType, pre-schedules which clusters it will live on
// according to uploadPriority, and records the package and target clusters in
// the database.
func (svc *JobSetService) CreatePackage(userID cdssdk.UserID, name string, dataType string, uploadPriority sch.UploadPriority) error {
	cdsCli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return fmt.Errorf("new cds client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(cdsCli)
	bucket, err := svc.db.Access().GetBucketByUserID(svc.db.DefCtx(), userID, dataType)
	if err != nil {
		return fmt.Errorf("failed to get bucket: %w", err)
	}
	if bucket == nil {
		return fmt.Errorf("bucket not found")
	}
	// Create the package in cloud storage.
	newPackage, err := cdsCli.Package().Create(cdsapi.PackageCreate{
		UserID:   userID,
		BucketID: bucket.ID,
		Name:     name,
	})
	if err != nil {
		return fmt.Errorf("failed to create package: %w", err)
	}
	pkg := uploadersdk.Package{
		UserID:         userID,
		PackageID:      newPackage.Package.PackageID,
		PackageName:    name,
		BucketID:       bucket.ID,
		DataType:       dataType,
		UploadPriority: uploadPriority,
	}
	// Pre-schedule the package onto clusters.
	clusters, err := svc.packageScheduler(pkg.PackageID, uploadPriority)
	if err != nil {
		return err
	}
	// Persist the package and its target clusters. (The original routed this
	// through svc.JobSetSvc(), a no-op on an existing *JobSetService.)
	return svc.db.UploadData().InsertPackage(svc.db.DefCtx(), pkg, clusters)
}
// packageScheduler decides which clusters a new package should be stored on.
// With *sch.Preferences the pre-scheduler picks the best cluster; with
// *sch.SpecifyCluster the caller's explicit cluster list is used (unknown
// clusters are skipped with a warning). Errors if no cluster is usable.
func (svc *JobSetService) packageScheduler(packageID cdssdk.PackageID, uploadPriority sch.UploadPriority) ([]uploadersdk.Cluster, error) {
	clusterMapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
	if err != nil {
		return nil, fmt.Errorf("query cluster mapping error: %w", err)
	}
	var clusters []uploadersdk.Cluster
	switch uploadPriority := uploadPriority.(type) {
	case *sch.Preferences:
		// Let the pre-scheduler choose a cluster from the resource priorities.
		clusterID, err := svc.preScheduler.ScheduleJob(uploadPriority.ResourcePriorities, clusterMapping)
		if err != nil {
			return nil, fmt.Errorf("pre scheduling: %w", err)
		}
		storageID, ok := clusterMapping[*clusterID]
		if !ok {
			// BUGFIX: clusterID is a pointer and was formatted with %d, which
			// printed an address rather than the ID; dereference and use %v.
			return nil, fmt.Errorf("cluster %v not found", *clusterID)
		}
		clusters = append(clusters, uploadersdk.Cluster{
			PackageID: packageID,
			ClusterID: *clusterID,
			StorageID: storageID,
		})
	case *sch.SpecifyCluster:
		// Caller specified target clusters explicitly.
		for _, clusterID := range uploadPriority.Clusters {
			storageID, ok := clusterMapping[clusterID]
			if !ok {
				// %v: ClusterID is formatted with %s elsewhere in this file,
				// so %d was the wrong verb here too.
				logger.Warnf("cluster %v not found", clusterID)
				continue
			}
			clusters = append(clusters, uploadersdk.Cluster{
				PackageID: packageID,
				ClusterID: clusterID,
				StorageID: storageID,
			})
		}
	}
	if len(clusters) == 0 {
		return nil, errors.New("no storage is available")
	}
	return clusters, nil
}
// DeletePackage removes the package from cloud storage and then deletes its
// local database record.
func (svc *JobSetService) DeletePackage(userID cdssdk.UserID, packageID cdssdk.PackageID) error {
	cdsCli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return fmt.Errorf("new cds client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(cdsCli)
	err = cdsCli.Package().Delete(cdsapi.PackageDelete{
		UserID:    userID,
		PackageID: packageID,
	})
	if err != nil {
		return fmt.Errorf("delete package: %w", err)
	}
	// The original routed this through svc.JobSetSvc(), a no-op on an
	// existing *JobSetService receiver.
	return svc.db.UploadData().DeletePackage(svc.db.DefCtx(), userID, packageID)
}
// QueryResource returns every cluster whose resources satisfy the requested
// resource range (see isAppropriateResources for the matching rules).
func (svc *JobSetService) QueryResource(queryResource sch.ResourceRange) ([]sch.ClusterDetail, error) {
	clusterDetails, err := svc.getClusterResources()
	if err != nil {
		return nil, err
	}
	var results []sch.ClusterDetail
	for _, cluster := range clusterDetails {
		// len() of a nil slice is 0, so the original explicit nil check was
		// redundant (staticcheck S1009).
		if len(cluster.Resources) == 0 {
			continue
		}
		if isAppropriateResources(cluster.Resources, queryResource) {
			results = append(results, cluster)
		}
	}
	return results, nil
}
// isAppropriateResources reports whether the resource list contains an entry
// of the requested type whose CPU, memory and storage base resources all fall
// inside the corresponding ranges of queryResource (see compareResource for
// the range semantics). Only the first matching resource type is examined.
func isAppropriateResources(resources []sch.ClusterResource, queryResource sch.ResourceRange) bool {
	for _, resource := range resources {
		if resource.Resource.Type != queryResource.Type {
			continue
		}
		// len() of a nil slice is 0; the explicit nil check was redundant
		// (staticcheck S1009).
		if len(resource.BaseResources) == 0 {
			return false
		}
		for _, baseResource := range resource.BaseResources {
			switch baseResource.Type {
			case sch.ResourceTypeCPU:
				if !compareResource(queryResource.CPU.Min, queryResource.CPU.Max, baseResource.Available.Value) {
					return false
				}
			case sch.ResourceTypeMemory:
				if !compareResource(queryResource.Memory.Min, queryResource.Memory.Max, baseResource.Available.Value) {
					return false
				}
			case sch.ResourceTypeStorage:
				if !compareResource(queryResource.Storage.Min, queryResource.Storage.Max, baseResource.Available.Value) {
					return false
				}
			}
		}
		return true
	}
	return false
}
// compareResource reports whether v lies inside the inclusive range
// [min, max]. An inverted range (min > max) never matches, and a fully
// unspecified range (min == max == 0) matches any value.
func compareResource(min float64, max float64, v float64) bool {
	switch {
	case min > max:
		return false
	case min == 0 && max == 0:
		// Unset range: accept everything.
		return true
	default:
		return min <= v && v <= max
	}
}
// ResourceRange aggregates, per accelerator resource type, the min/max CPU,
// memory and storage available across all clusters, and counts how many
// clusters offer each type (stored in GPUNumber).
func (svc *JobSetService) ResourceRange() ([]sch.ResourceRange, error) {
	clusterDetails, err := svc.getClusterResources()
	if err != nil {
		return nil, err
	}
	// One accumulated ResourceRange per resource type.
	resourceMap := make(map[sch.ResourceType]sch.ResourceRange)
	for _, cluster := range clusterDetails {
		var CPUValue float64
		var MemValue float64
		var StorageValue float64
		// NOTE(review): these hold the LAST base-resource value seen across
		// all of this cluster's resources — not a sum and not per-resource.
		// Confirm that is intended.
		for _, resource := range cluster.Resources {
			for _, baseResource := range resource.BaseResources {
				switch baseResource.Type {
				case sch.ResourceTypeCPU:
					CPUValue = baseResource.Available.Value
				case sch.ResourceTypeMemory:
					MemValue = baseResource.Available.Value
				case sch.ResourceTypeStorage:
					StorageValue = baseResource.Available.Value
				}
			}
		}
		// Fold this cluster's values into the range of each resource type it offers.
		for _, resource := range cluster.Resources {
			resourceType := resource.Resource.Type
			resourceRange, exists := resourceMap[resourceType]
			if !exists {
				// First time we see this type: start from empty ranges.
				resourceRange = sch.ResourceRange{
					Type:      resourceType,
					GPU:       sch.Range{},
					GPUNumber: 0,
					CPU:       sch.Range{},
					Memory:    sch.Range{},
					Storage:   sch.Range{},
				}
			}
			// Widen min/max; a stored min of 0 is treated as "unset".
			if CPUValue < resourceRange.CPU.Min || resourceRange.CPU.Min == 0 {
				resourceRange.CPU.Min = CPUValue
			}
			if CPUValue > resourceRange.CPU.Max {
				resourceRange.CPU.Max = CPUValue
			}
			if MemValue < resourceRange.Memory.Min || resourceRange.Memory.Min == 0 {
				resourceRange.Memory.Min = MemValue
			}
			if MemValue > resourceRange.Memory.Max {
				resourceRange.Memory.Max = MemValue
			}
			if StorageValue < resourceRange.Storage.Min || resourceRange.Storage.Min == 0 {
				resourceRange.Storage.Min = StorageValue
			}
			if StorageValue > resourceRange.Storage.Max {
				resourceRange.Storage.Max = StorageValue
			}
			// Count clusters offering this resource type.
			resourceRange.GPUNumber++
			resourceMap[resourceType] = resourceRange
		}
	}
	// Flatten the map into a slice for the caller.
	var result []sch.ResourceRange
	for _, rangeData := range resourceMap {
		result = append(result, rangeData)
	}
	return result, nil
}
// getClusterResources fetches resource details for every cluster known to the
// cluster-mapping table via the PCM scheduler.
func (svc *JobSetService) getClusterResources() ([]sch.ClusterDetail, error) {
	schCli, err := schglb.PCMSchePool.Acquire()
	if err != nil {
		return nil, fmt.Errorf("new scheduler client: %w", err)
	}
	defer schglb.PCMSchePool.Release(schCli)
	mapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
	if err != nil {
		return nil, fmt.Errorf("query cluster mapping: %w", err)
	}
	// Collect the IDs of every known cluster.
	ids := make([]schsdk.ClusterID, 0, len(mapping))
	for id := range mapping {
		ids = append(ids, id)
	}
	details, err := schCli.GetClusterInfo(sch.GetClusterInfoReq{IDs: ids})
	if err != nil {
		return nil, fmt.Errorf("get cluster info: %w", err)
	}
	if len(details) == 0 {
		return nil, errors.New("no cluster found")
	}
	return details, nil
}
// QueryImages looks up cluster images by their IDs.
func (svc *JobSetService) QueryImages(IDs []int64) ([]sch.ClusterImage, error) {
	imgs, err := svc.db.UploadData().GetImageByID(svc.db.DefCtx(), IDs)
	if err != nil {
		return nil, fmt.Errorf("query images: %w", err)
	}
	return imgs, nil
}
// ClonePackage records a package as an algorithm version, optionally cloning
// it first.
//
// cloneType selects the behavior:
//   - sch.ChildrenType: requires that a parent record already exists; the
//     package is physically cloned via the storage service (with a unix
//     timestamp appended to the name to keep versions unique) and the clone
//     is recorded as a child version.
//   - sch.ParentType: requires that no parent record exists yet; the original
//     package itself is recorded as the parent without a physical clone.
//
// Any other cloneType is rejected. On success it returns the recorded
// package: the clone for ChildrenType, the original for ParentType.
func (svc *JobSetService) ClonePackage(userID cdssdk.UserID, param uploadersdk.PackageCloneParam, cloneType string) (*cdssdk.Package, error) {
	clonePackageID := param.PackageID
	pkg := cdssdk.Package{
		PackageID: param.PackageID,
		Name:      param.PackageName,
		//BucketID: param.BucketID,
	}
	parentPkg, err := svc.db.UploadData().GetParentClonePackageByPkgID(svc.db.DefCtx(), param.PackageID)
	if err != nil {
		return nil, fmt.Errorf("query clone package: %w", err)
	}
	switch cloneType {
	case sch.ChildrenType:
		// A child version can only be added under an existing parent record.
		if parentPkg.ParentPackageID == 0 {
			return nil, errors.New("code parent package does not exist")
		}
		// Look up the package to obtain the bucket ID needed by the clone request.
		queryPkg, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), param.PackageID)
		if err != nil {
			return nil, fmt.Errorf("query package: %w", err)
		}
		cdsCli, err := schglb.CloudreamStoragePool.Acquire()
		if err != nil {
			return nil, fmt.Errorf("new cds client: %w", err)
		}
		defer schglb.CloudreamStoragePool.Release(cdsCli)
		// Suffix the clone name with a unix timestamp so version names stay unique.
		version := strconv.FormatInt(time.Now().Unix(), 10)
		packageName := fmt.Sprintf("%s_%s", param.PackageName, version)
		cloneResp, err := cdsCli.Package().Clone(cdsapi.PackageClone{
			PackageID: param.PackageID,
			Name:      packageName,
			BucketID:  queryPkg.BucketID,
			UserID:    userID,
		})
		if err != nil {
			return nil, fmt.Errorf("clone package: %w", err)
		}
		clonePackageID = cloneResp.Package.PackageID
		pkg = cloneResp.Package
	case sch.ParentType:
		// A parent record may only be created once per package.
		if parentPkg.ParentPackageID != 0 {
			return nil, errors.New("code parent package already exists")
		}
	default:
		// Reject unknown types instead of silently inserting a version record.
		return nil, fmt.Errorf("unknown clone type: %s", cloneType)
	}
	// Record the (parent, clone) relationship in the version table.
	err = svc.db.UploadData().InsertClonePackage(svc.db.DefCtx(), uploadersdk.PackageCloneDAO{
		ParentPackageID:   param.PackageID,
		ClonePackageID:    clonePackageID,
		Name:              param.Name,
		Description:       param.Description,
		BootstrapObjectID: param.BootstrapObjectID,
		ParentImageID:     param.ParentImageID,
		ImageID:           param.ImageID,
		ClusterID:         param.ClusterID,
		CreateTime:        time.Now(),
	})
	if err != nil {
		return nil, fmt.Errorf("insert package version: %w", err)
	}
	return &pkg, nil
}
// QueryClonePackage lists clone-package records. The sentinel packageID -1
// returns the user's parent packages for the given data type; any other
// packageID returns the child packages cloned from it.
func (svc *JobSetService) QueryClonePackage(packageID cdssdk.PackageID, userID cdssdk.UserID, dataType string) ([]uploadersdk.PackageCloneVO, error) {
	if packageID != -1 {
		// Concrete package ID: list its child (cloned) packages.
		children, err := svc.db.UploadData().GetChildrenClonePackageByPkgID(svc.db.DefCtx(), packageID)
		if err != nil {
			return nil, fmt.Errorf("query children package: %w", err)
		}
		return children, nil
	}
	// Sentinel -1: list the parent packages owned by the user.
	parents, err := svc.db.UploadData().GetCloneParentPackage(svc.db.DefCtx(), userID, dataType)
	if err != nil {
		return nil, fmt.Errorf("query clone parent package: %w", err)
	}
	return parents, nil
}
// RemoveClonePackage deletes clone-package records. Parent records are only
// removed from the database; for child records, each cloned package is first
// deleted from the storage service and then the database rows are removed.
func (svc *JobSetService) RemoveClonePackage(userID cdssdk.UserID, cloneType string, packageIDs []cdssdk.PackageID) error {
	if cloneType == sch.ParentType {
		// Parent records: drop the database rows only.
		if err := svc.db.UploadData().RemoveClonePackage(svc.db.DefCtx(), packageIDs, []cdssdk.PackageID{}); err != nil {
			return fmt.Errorf("remove clone package: %w", err)
		}
		return nil
	}

	cdsCli, err := schglb.CloudreamStoragePool.Acquire()
	if err != nil {
		return fmt.Errorf("new cds client: %w", err)
	}
	defer schglb.CloudreamStoragePool.Release(cdsCli)

	// Physically delete each cloned package from storage first.
	for _, pkgID := range packageIDs {
		req := cdsapi.PackageDelete{
			UserID:    userID,
			PackageID: pkgID,
		}
		if err := cdsCli.Package().Delete(req); err != nil {
			return fmt.Errorf("delete package: %w", err)
		}
	}

	if err := svc.db.UploadData().RemoveClonePackage(svc.db.DefCtx(), []cdssdk.PackageID{}, packageIDs); err != nil {
		return fmt.Errorf("remove clone package: %w", err)
	}
	return nil
}