1446 lines
40 KiB
Go
1446 lines
40 KiB
Go
package services
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/algorithm"
|
||
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/dataset"
|
||
"gitlink.org.cn/JointCloud/pcm-hub/aikit/common/model"
|
||
jobTask "gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/task"
|
||
"sort"
|
||
"strconv"
|
||
|
||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||
sch "gitlink.org.cn/cloudream/common/sdks/pcmscheduler"
|
||
"gitlink.org.cn/cloudream/common/sdks/storage/cdsapi"
|
||
uploadersdk "gitlink.org.cn/cloudream/common/sdks/uploader"
|
||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||
jobmod "gitlink.org.cn/cloudream/scheduler/common/models/job"
|
||
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr"
|
||
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/event"
|
||
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job"
|
||
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job/state"
|
||
"gitlink.org.cn/cloudream/scheduler/schedulerMiddleware/internal/manager/jobmgr/job/state2"
|
||
"strings"
|
||
"time"
|
||
|
||
schsdk "gitlink.org.cn/cloudream/common/sdks/scheduler"
|
||
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
|
||
)
|
||
|
||
// JobSetService exposes job-set-related operations (pre-scheduling,
// submission, data upload, data binding and package management) on top of
// the shared Service.
type JobSetService struct {
	*Service
}
|
||
|
||
func (svc *Service) JobSetSvc() *JobSetService {
|
||
return &JobSetService{Service: svc}
|
||
}
|
||
|
||
func (svc *JobSetService) PreScheduler(jobSet schsdk.JobSetInfo) (*jobmod.JobSetPreScheduleScheme, *schsdk.JobSetFilesUploadScheme, error) {
|
||
ccs, err := svc.db.ComputingCenter().GetAll(svc.db.DefCtx())
|
||
if err != nil {
|
||
logger.Warnf("getting all computing center: %s", err.Error())
|
||
return nil, nil, err
|
||
}
|
||
|
||
schScheme, uploadScheme, err := svc.preScheduler.ScheduleJobSet(&jobSet, ccs)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("pre scheduling: %w", err)
|
||
}
|
||
|
||
return schScheme, uploadScheme, nil
|
||
}
|
||
|
||
// Upload submits a data-upload job for an existing package: it resolves the
// package's bucket, configures the hub client, and enqueues a single normal
// job whose initial state performs the actual upload. Returns the ID of the
// newly submitted job set.
func (svc *JobSetService) Upload(userID cdssdk.UserID, packageID cdssdk.PackageID, params sch.UploadParams, blockChainToken string, task *jobTask.JobTask[sch.TaskMessage]) (*schsdk.JobSetID, error) {
	logger.Debugf("uploading job")

	// Look up the bucket that owns this package; PackageID == 0 is the DAO's
	// "not found" sentinel.
	pkg, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), packageID)
	if err != nil {
		logger.Warnf("getting upload data: %s", err.Error())
		return nil, err
	}
	if pkg.PackageID == 0 {
		return nil, errors.New("packageID is not found")
	}
	// NOTE(review): this mutates shared hubClient state; if Upload can run
	// concurrently for different users this is a data race — confirm the
	// hub client is used single-threaded or per-request.
	svc.hubClient.UserId = int(userID)
	svc.hubClient.BucketID = int(pkg.BucketID)

	// Initial state carrying everything the upload state machine needs.
	upload := state2.DataUpload{
		UserID:          userID,
		PackageID:       packageID,
		BlockChainToken: blockChainToken,
		Task:            task,
		HubClient:       svc.hubClient,
		DataType:        params.DataType,
		UploadInfo:      params.UploadInfo,
	}

	// The upload is modeled as a one-job job set with an empty NormalJob body.
	var jobs []jobmgr.SubmittingJob
	jo := job.NewNormalJob(schsdk.NormalJobInfo{})
	jobs = append(jobs, jobmgr.SubmittingJob{
		Body:      jo,
		InitState: state2.NewDataUpload(&upload),
	})

	jobSetID := svc.jobMgr.SubmitJobSet(jobs)

	return &jobSetID, nil
}
|
||
|
||
// Submit 提交任务集
|
||
func (svc *JobSetService) Submit(userID cdssdk.UserID, jobSet schsdk.JobSetInfo, schScheme *jobmod.JobSetPreScheduleScheme, token string) (*schsdk.JobSetID, error) {
|
||
logger.Debugf("submitting job")
|
||
|
||
var jobs []jobmgr.SubmittingJob
|
||
for _, jobInfo := range jobSet.Jobs {
|
||
switch info := jobInfo.(type) {
|
||
case *schsdk.PCMJobInfo:
|
||
jo := job.NewPCMJob(*info)
|
||
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: jo,
|
||
//InitState: state.NewPreSchuduling(preSch),
|
||
InitState: state2.NewPCMJobCreate(userID, info, token, svc.hubClient),
|
||
})
|
||
|
||
case *schsdk.NormalJobInfo:
|
||
jo := job.NewNormalJob(*info)
|
||
jo.SubType = schsdk.JobTypeNormal
|
||
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: jo,
|
||
//InitState: state.NewPreSchuduling(preSch),
|
||
})
|
||
|
||
case *schsdk.DataReturnJobInfo:
|
||
jo := job.NewDataReturnJob(*info)
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: jo,
|
||
InitState: state.NewWaitTargetComplete(),
|
||
})
|
||
|
||
case *schsdk.MultiInstanceJobInfo:
|
||
preSch, ok := schScheme.JobSchemes[info.LocalJobID]
|
||
|
||
jo := job.NewMultiInstanceJob(*info, preSch)
|
||
|
||
if !ok {
|
||
return nil, errors.New(fmt.Sprintf("pre schedule scheme for job %s is not found", info.LocalJobID))
|
||
}
|
||
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: jo,
|
||
InitState: state.NewMultiInstanceInit(),
|
||
})
|
||
|
||
case *schsdk.UpdateMultiInstanceJobInfo:
|
||
modelJob := job.NewUpdateMultiInstanceJob(*info)
|
||
instanceJobSets := svc.jobMgr.DumpJobSet(modelJob.Info.MultiInstanceJobSetID)
|
||
if len(instanceJobSets) == 0 {
|
||
return nil, errors.New(fmt.Sprintf("job set %s is not found", modelJob.Info.MultiInstanceJobSetID))
|
||
}
|
||
|
||
// 找到多实例任务本身
|
||
var multiInstanceJobDump jobmod.JobDump
|
||
for i := 0; i < len(instanceJobSets); i++ {
|
||
jobDump := instanceJobSets[i]
|
||
if _, ok := jobDump.Body.(*jobmod.MultiInstanceJobDump); ok {
|
||
multiInstanceJobDump = jobDump
|
||
break
|
||
}
|
||
}
|
||
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: modelJob,
|
||
InitState: state.NewMultiInstanceUpdate(multiInstanceJobDump),
|
||
})
|
||
|
||
case *schsdk.DataPreprocessJobInfo:
|
||
// 后续的调度流程跟NormalJob是一致的
|
||
normalJobInfo := &schsdk.NormalJobInfo{
|
||
Type: schsdk.JobTypeNormal,
|
||
JobInfoBase: info.JobInfoBase,
|
||
Files: info.Files,
|
||
Runtime: info.Runtime,
|
||
Services: info.Services,
|
||
Resources: info.Resources,
|
||
}
|
||
jo := job.NewNormalJob(*normalJobInfo)
|
||
jo.SubType = schsdk.JobTypeDataPreprocess
|
||
|
||
preSch, ok := schScheme.JobSchemes[info.LocalJobID]
|
||
if !ok {
|
||
return nil, errors.New(fmt.Sprintf("pre schedule scheme for job %s is not found", info.LocalJobID))
|
||
}
|
||
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: jo,
|
||
InitState: state.NewPreSchuduling(preSch),
|
||
})
|
||
|
||
case *schsdk.FinetuningJobInfo:
|
||
// 后续的调度流程跟NormalJob是一致的
|
||
normalJobInfo := &schsdk.NormalJobInfo{
|
||
Type: schsdk.JobTypeNormal,
|
||
Files: info.Files,
|
||
JobInfoBase: info.JobInfoBase,
|
||
Runtime: info.Runtime,
|
||
Services: info.Services,
|
||
Resources: info.Resources,
|
||
ModelJobInfo: info.ModelJobInfo,
|
||
}
|
||
jo := job.NewNormalJob(*normalJobInfo)
|
||
jo.SubType = schsdk.JobTypeFinetuning
|
||
|
||
preSch, ok := schScheme.JobSchemes[info.LocalJobID]
|
||
if !ok {
|
||
return nil, errors.New(fmt.Sprintf("pre schedule scheme for job %s is not found", info.LocalJobID))
|
||
}
|
||
|
||
jobs = append(jobs, jobmgr.SubmittingJob{
|
||
Body: jo,
|
||
InitState: state.NewPreSchuduling(preSch),
|
||
})
|
||
}
|
||
}
|
||
|
||
jobSetID := svc.jobMgr.SubmitJobSet(jobs)
|
||
return &jobSetID, nil
|
||
}
|
||
|
||
// LocalFileUploaded 任务集中某个文件上传完成
|
||
func (svc *JobSetService) LocalFileUploaded(jobSetID schsdk.JobSetID, localPath string, errMsg string, packageID cdssdk.PackageID, objectIDs []cdssdk.ObjectID) {
|
||
|
||
err := errors.New(errMsg)
|
||
|
||
svc.jobMgr.BroadcastEvent(jobSetID, event.NewLocalFileUploaded(localPath, err, packageID, objectIDs))
|
||
}
|
||
|
||
func (svc *JobSetService) CreateFolder(packageID cdssdk.PackageID, path string) error {
|
||
err := svc.JobSetSvc().db.UploadData().InsertFolder(svc.db.DefCtx(), packageID, path)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 删除文件或文件夹
|
||
func (svc *JobSetService) DeleteFile(userID cdssdk.UserID, objectIDs []cdssdk.ObjectID) error {
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return fmt.Errorf("new scheduler client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
err = cdsCli.Object().Delete(cdsapi.ObjectDelete{
|
||
ObjectIDs: objectIDs,
|
||
UserID: userID,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete object: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (svc *JobSetService) DeleteFolder(userID cdssdk.UserID, packageID cdssdk.PackageID, path string) error {
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return fmt.Errorf("new scheduler client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
list, err := cdsCli.Object().ListByPath(cdsapi.ObjectListByPath{
|
||
UserID: userID,
|
||
PackageID: packageID,
|
||
Path: path,
|
||
IsPrefix: true,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete object: %w", err)
|
||
}
|
||
|
||
if len(list.Objects) > 0 {
|
||
var objectIDs []cdssdk.ObjectID
|
||
for _, obj := range list.Objects {
|
||
objectIDs = append(objectIDs, obj.ObjectID)
|
||
}
|
||
|
||
err = cdsCli.Object().Delete(cdsapi.ObjectDelete{
|
||
ObjectIDs: objectIDs,
|
||
UserID: userID,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete object: %w", err)
|
||
}
|
||
}
|
||
|
||
err = svc.JobSetSvc().db.UploadData().DeleteFolder(svc.db.DefCtx(), packageID, path)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete object: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (svc *JobSetService) QueryUploaded(queryParams sch.QueryData) ([]uploadersdk.Package, int, int, error) {
|
||
// 查询根目录
|
||
if queryParams.PackageID == -1 {
|
||
packages, err := svc.JobSetSvc().db.UploadData().QueryPackage(svc.db.DefCtx(), queryParams)
|
||
if err != nil {
|
||
return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
|
||
}
|
||
return packages, 0, 0, nil
|
||
}
|
||
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return nil, 0, 0, fmt.Errorf("new scheduler client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
queryListReq := cdsapi.ObjectListByPath{
|
||
UserID: queryParams.UserID,
|
||
PackageID: queryParams.PackageID,
|
||
Path: queryParams.Path,
|
||
IsPrefix: true,
|
||
}
|
||
objList, err := cdsCli.Object().ListByPath(queryListReq)
|
||
if err != nil {
|
||
return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
|
||
}
|
||
|
||
folderMap := make(map[string]cdssdk.Object)
|
||
var modifyObjs []cdssdk.Object
|
||
for _, obj := range objList.Objects {
|
||
// 去掉obj中path从0到queryParams.Path这段字符串
|
||
obj.Path = strings.TrimPrefix(obj.Path, queryParams.Path)
|
||
pathArr := strings.Split(obj.Path, "/")
|
||
if len(pathArr) > 2 {
|
||
splitPath := "/" + pathArr[1]
|
||
folderMap[splitPath] = cdssdk.Object{
|
||
ObjectID: -1,
|
||
PackageID: obj.PackageID,
|
||
Path: pathArr[1],
|
||
Size: 0,
|
||
CreateTime: obj.CreateTime,
|
||
}
|
||
continue
|
||
}
|
||
modifyObjs = append(modifyObjs, obj)
|
||
}
|
||
|
||
folders, err := svc.db.UploadData().QueryFolder(svc.db.DefCtx(), queryParams)
|
||
if err != nil {
|
||
return nil, 0, 0, fmt.Errorf("failed to query uploaded data: %w", err)
|
||
}
|
||
|
||
for _, folder := range folders {
|
||
folder.Path = strings.TrimPrefix(folder.Path, queryParams.Path)
|
||
if folder.Path == "" {
|
||
continue
|
||
}
|
||
folderMap[folder.Path] = cdssdk.Object{
|
||
ObjectID: -1,
|
||
PackageID: folder.PackageID,
|
||
Path: folder.Path,
|
||
Size: 0,
|
||
CreateTime: folder.CreateTime,
|
||
}
|
||
}
|
||
|
||
// 遍历folderMap,将folderMap的值赋给objList.Objects
|
||
for _, obj := range folderMap {
|
||
modifyObjs = append(modifyObjs, obj)
|
||
}
|
||
objList.Objects = modifyObjs
|
||
|
||
// 根据orderBy字段排序
|
||
sort.Slice(objList.Objects, func(i, j int) bool {
|
||
if queryParams.OrderBy == sch.OrderByName {
|
||
return objList.Objects[i].Path < objList.Objects[j].Path
|
||
} else if queryParams.OrderBy == sch.OrderBySize {
|
||
return objList.Objects[i].Size < objList.Objects[j].Size
|
||
} else if queryParams.OrderBy == sch.OrderByTime {
|
||
return objList.Objects[i].CreateTime.Unix() < objList.Objects[j].CreateTime.Unix()
|
||
}
|
||
return false
|
||
})
|
||
|
||
totalNum := len(objList.Objects)
|
||
|
||
// 分页返回
|
||
if queryParams.PageSize > 0 {
|
||
start := (queryParams.CurrentPage - 1) * queryParams.PageSize
|
||
end := start + queryParams.PageSize
|
||
if start > totalNum {
|
||
return nil, 0, 0, nil
|
||
}
|
||
if end > totalNum {
|
||
end = totalNum
|
||
}
|
||
objList.Objects = objList.Objects[start:end]
|
||
}
|
||
totalPages := totalNum / queryParams.PageSize
|
||
|
||
var datas []uploadersdk.Package
|
||
data, err := svc.db.UploadData().QueryPackageByID(svc.db.DefCtx(), queryParams.PackageID)
|
||
if err != nil {
|
||
return nil, 0, 0, err
|
||
}
|
||
pkg := uploadersdk.Package{
|
||
PackageID: data.PackageID,
|
||
BucketID: data.BucketID,
|
||
DataType: data.DataType,
|
||
PackageName: data.PackageName,
|
||
JsonData: data.JsonData,
|
||
BindingID: data.BindingID,
|
||
UserID: data.UserID,
|
||
CreateTime: data.CreateTime,
|
||
Objects: objList.Objects,
|
||
UploadedCluster: data.UploadedCluster,
|
||
}
|
||
datas = append(datas, pkg)
|
||
|
||
return datas, totalPages, totalNum, nil
|
||
}
|
||
|
||
// DataBinding binds an uploaded package to a dataset, code (algorithm) or
// model entity on every requested cluster. For each binding type it:
//  1. validates the request via queryAndVerify (unique name, package exists
//     and is unbound, clusters resolvable);
//  2. schedules the package data onto any target cluster that does not hold
//     it yet (dataSchedule);
//  3. registers the binding with the cluster's platform (ModelArts / OpenI)
//     and records the platform's response JSON per cluster;
//  4. persists the binding plus per-cluster results in one DB call.
// Image bindings are explicitly unsupported.
func (svc *JobSetService) DataBinding(id uploadersdk.DataID, userID cdssdk.UserID, info sch.DataBinding) error {

	var bindingData uploadersdk.Binding
	var bindingClusters []uploadersdk.BindingCluster
	var packageID cdssdk.PackageID

	switch bd := info.(type) {
	case *sch.DatasetBinding:
		pkg, clusterMap, err := svc.queryAndVerify(bd.PackageID, bd.ClusterIDs, userID, bd.Name, sch.DATASET)
		if err != nil {
			return err
		}

		param := dataScheduleParam{
			BindingClusterIDs: bd.ClusterIDs,
			ClusterMap:        clusterMap,
			PackageID:         bd.PackageID,
			UploadedClusters:  pkg.UploadedCluster,
			UserID:            userID,
		}
		// If a target cluster does not hold the data yet, schedule a transfer first.
		err = svc.dataSchedule(param)
		if err != nil {
			return err
		}

		for _, clusterID := range bd.ClusterIDs {
			clusterInfo, ok := clusterMap[clusterID]
			if !ok {
				return fmt.Errorf("cluster %s not found", clusterID)
			}

			// Dataset files live under <storagePath>/<DATASET>/<packageName>.
			filePath := clusterInfo.StoragePath + sch.Split + sch.DATASET + sch.Split + pkg.PackageName
			datasetParam := &dataset.Dataset{
				Name:        bd.Name,
				Description: bd.Description,
				FilePath:    filePath,
				Category:    dataset.CommonValue(bd.Category),
			}

			// Register the dataset with the cluster's platform; keep the raw
			// response for later use. Unknown platforms leave jsonData nil.
			var jsonData []byte
			switch clusterInfo.ClusterName {
			case sch.PlatformModelArts:
				resp, err := svc.hubClient.ModelArts.BindDataset(datasetParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			case sch.PlatformOpenI:
				datasetParam.RepoName = ""
				resp, err := svc.hubClient.OpenI.BindDataset(datasetParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			}

			bindingClusters = append(bindingClusters, uploadersdk.BindingCluster{
				ClusterID: uploadersdk.ClusterID(clusterID),
				Status:    sch.SuccessStatus,
				JsonData:  string(jsonData),
			})
		}

		// Store the full request as the binding's content payload.
		content, err := json.Marshal(bd)
		if err != nil {
			return err
		}
		bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content))

		packageID = bd.PackageID
	case *sch.CodeBinding:
		pkg, clusterMap, err := svc.queryAndVerify(bd.PackageID, bd.ClusterIDs, userID, bd.Name, sch.CODE)
		if err != nil {
			return err
		}

		param := dataScheduleParam{
			BindingClusterIDs: bd.ClusterIDs,
			ClusterMap:        clusterMap,
			PackageID:         bd.PackageID,
			UploadedClusters:  pkg.UploadedCluster,
			UserID:            userID,
		}
		// If a target cluster does not hold the data yet, schedule a transfer first.
		err = svc.dataSchedule(param)
		if err != nil {
			return err
		}

		for _, clusterID := range bd.ClusterIDs {
			clusterInfo, ok := clusterMap[clusterID]
			if !ok {
				return fmt.Errorf("cluster %s not found", clusterID)
			}

			// Code lives under <storagePath>/<CODE>/<packageName>.
			filePath := clusterInfo.StoragePath + sch.Split + sch.CODE + sch.Split + pkg.PackageName
			codeParam := &algorithm.Algorithm{
				Name:        bd.Name,
				Description: bd.Description,
				CodeDir:     filePath,
				//Engine: "",
				BootFile: "",
			}

			var jsonData []byte
			switch clusterInfo.ClusterName {
			case sch.PlatformModelArts:
				// NOTE(review): engine is built but never passed to
				// BindAlgorithm, and builtin println on a struct is debug
				// residue — confirm whether the engine was meant to be used.
				engine := algorithm.Engine{
					EngineName:         "Ascend-Powered-Engine",
					EngineVersion:      "", // extracted from the URL
					ImageUrl:           "", // fetched from the database
					InstallSysPackages: true,
				}

				println(engine)

				resp, err := svc.hubClient.ModelArts.BindAlgorithm(codeParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			case sch.PlatformOpenI:
				// OpenI does not take an Engine.
				resp, err := svc.hubClient.OpenI.BindAlgorithm(codeParam)
				if err != nil {
					return err
				}
				err = svc.hubClient.UploadOpenIAlgorithm(codeParam, int(bd.PackageID))
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			}

			bindingClusters = append(bindingClusters, uploadersdk.BindingCluster{
				ClusterID: uploadersdk.ClusterID(clusterID),
				Status:    sch.SuccessStatus,
				JsonData:  string(jsonData),
			})
		}

		content, err := json.Marshal(bd)
		if err != nil {
			return err
		}
		bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content))

		packageID = bd.PackageID
	case *sch.ImageBinding:
		//content, err := json.Marshal(bd)
		//if err != nil {
		//	return err
		//}
		//bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content), "")
		//packageID = bd.PackageID
		return fmt.Errorf("not support image binding")
	case *sch.ModelBinding:
		pkg, clusterMap, err := svc.queryAndVerify(bd.PackageID, bd.ClusterIDs, userID, bd.Name, sch.MODEL)
		if err != nil {
			return err
		}

		param := dataScheduleParam{
			BindingClusterIDs: bd.ClusterIDs,
			ClusterMap:        clusterMap,
			PackageID:         bd.PackageID,
			UploadedClusters:  pkg.UploadedCluster,
			UserID:            userID,
		}
		// If a target cluster does not hold the data yet, schedule a transfer first.
		err = svc.dataSchedule(param)
		if err != nil {
			return err
		}

		for _, clusterID := range bd.ClusterIDs {
			clusterInfo, ok := clusterMap[clusterID]
			if !ok {
				return fmt.Errorf("cluster %s not found", clusterID)
			}

			// Models live under <storagePath>/<MODEL>/<packageName>.
			filePath := clusterInfo.StoragePath + sch.Split + sch.MODEL + sch.Split + pkg.PackageName
			modelParam := &model.Model{
				Name:        bd.Name,
				Description: bd.Description,
				Type:        bd.ModelType,
				FilePath:    filePath,
				//Engine: model.CommonValue(),
				//Version:
			}

			var jsonData []byte
			switch clusterInfo.ClusterName {
			case sch.PlatformModelArts:
				resp, err := svc.hubClient.ModelArts.BindModel(modelParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			case sch.PlatformOpenI:
				modelParam.RepoName = ""
				resp, err := svc.hubClient.OpenI.BindModel(modelParam)
				if err != nil {
					return err
				}
				jsonData, err = json.Marshal(resp.Data)
				if err != nil {
					return err
				}
			}

			bindingClusters = append(bindingClusters, uploadersdk.BindingCluster{
				ClusterID: uploadersdk.ClusterID(clusterID),
				Status:    sch.SuccessStatus,
				JsonData:  string(jsonData),
			})
		}

		content, err := json.Marshal(bd)
		if err != nil {
			return err
		}
		bindingData = getBindingData(id, userID, bd.Type, bd.Name, string(content))

		packageID = bd.PackageID
	}

	// Default access level is private.
	if bindingData.AccessLevel == "" {
		bindingData.AccessLevel = sch.PrivateAccess
	}
	bindingData.CreateTime = time.Now()
	_, err := svc.db.UploadData().InsertOrUpdateBinding(svc.db.DefCtx(), bindingData, bindingClusters, packageID)
	if err != nil {
		return err
	}

	return nil
}
|
||
|
||
// dataScheduleParam carries the inputs for dataSchedule: which clusters a
// binding targets, where the package data already resides, and the lookup
// table used to resolve cluster IDs to their storage locations.
type dataScheduleParam struct {
	UserID            cdssdk.UserID    // owner of the package
	PackageID         cdssdk.PackageID // package whose data is scheduled
	BindingClusterIDs []schsdk.ClusterID // clusters the binding targets
	UploadedClusters  []uploadersdk.Cluster // clusters that already hold the data
	ClusterMap        map[schsdk.ClusterID]uploadersdk.ClusterMapping // clusterID -> storage mapping
}
|
||
|
||
func (svc *JobSetService) dataSchedule(param dataScheduleParam) error {
|
||
// 筛选出需要数据调度的集群
|
||
var clusters []uploadersdk.ClusterMapping
|
||
for _, cid := range param.BindingClusterIDs {
|
||
isMatch := false
|
||
for _, cluster := range param.UploadedClusters {
|
||
if cid == cluster.ClusterID {
|
||
isMatch = true
|
||
break
|
||
}
|
||
}
|
||
if !isMatch {
|
||
clusterInfo, ok := param.ClusterMap[cid]
|
||
if ok {
|
||
clusters = append(clusters, clusterInfo)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 进行数据调度
|
||
for _, cluster := range clusters {
|
||
// 数据调度
|
||
_, err := svc.hubClient.LoadPackage(uint(param.PackageID), uint(param.UserID), uint(cluster.StorageID), "")
|
||
if err != nil {
|
||
logger.Error("data schedule failed, error: ", err.Error())
|
||
return err
|
||
}
|
||
// 将调度成功的集群加入到uploadedCluster
|
||
cluster := uploadersdk.Cluster{
|
||
PackageID: param.PackageID,
|
||
ClusterID: cluster.ClusterID,
|
||
StorageID: cluster.StorageID,
|
||
}
|
||
err = svc.db.UploadData().InsertUploadedCluster(svc.db.DefCtx(), cluster)
|
||
if err != nil {
|
||
logger.Error("insert uploadedCluster failed, error: ", err.Error())
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (svc *JobSetService) queryAndVerify(pacakgeID cdssdk.PackageID, clusterIDs []schsdk.ClusterID, userID cdssdk.UserID, name string, dataType string) (*uploadersdk.PackageDAO, map[schsdk.ClusterID]uploadersdk.ClusterMapping, error) {
|
||
// 查询是否已经绑定
|
||
existBinding, err := svc.db.UploadData().GetBindingByName(svc.db.DefCtx(), userID, name, dataType)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
if existBinding.ID != 0 {
|
||
return nil, nil, fmt.Errorf("name %s already exists", name)
|
||
}
|
||
|
||
// 查询package
|
||
pkg, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), pacakgeID)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
if pkg.PackageID == 0 {
|
||
return nil, nil, fmt.Errorf("no package found")
|
||
}
|
||
|
||
// 如果这个package已经被绑定,则不允许再绑定
|
||
if pkg.BindingID != -1 {
|
||
binding, err := svc.db.UploadData().GetBindingsByID(svc.db.DefCtx(), pkg.BindingID)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
return nil, nil, fmt.Errorf("binding already exists, name: " + binding.Name)
|
||
}
|
||
|
||
clusterMap := make(map[schsdk.ClusterID]uploadersdk.ClusterMapping)
|
||
clusters, err := svc.db.UploadData().GetClusterByID(svc.db.DefCtx(), clusterIDs)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
for _, cluster := range clusters {
|
||
clusterMap[cluster.ClusterID] = cluster
|
||
}
|
||
|
||
return &pkg, clusterMap, nil
|
||
}
|
||
|
||
func getBindingData(id uploadersdk.DataID, userID cdssdk.UserID, dataType string, name string, content string) uploadersdk.Binding {
|
||
bindingData := uploadersdk.Binding{
|
||
ID: id,
|
||
Name: name,
|
||
DataType: dataType,
|
||
UserID: userID,
|
||
Content: content,
|
||
//JsonData: jsonData,
|
||
}
|
||
return bindingData
|
||
}
|
||
|
||
func (svc *JobSetService) RemoveBinding(pacakgeIDs []cdssdk.PackageID, bindingIDs []int64) error {
|
||
|
||
if len(pacakgeIDs) > 0 {
|
||
for _, id := range pacakgeIDs {
|
||
pkgDao := uploadersdk.PackageDAO{
|
||
BindingID: -1,
|
||
}
|
||
err := svc.db.UploadData().UpdatePackage(svc.db.DefCtx(), id, pkgDao)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(bindingIDs) > 0 {
|
||
err := svc.db.UploadData().DeleteBindingsByID(svc.db.DefCtx(), bindingIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (svc *JobSetService) DeleteBinding(IDs []int64) error {
|
||
|
||
err := svc.db.UploadData().DeleteBindingsByID(svc.db.DefCtx(), IDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// QueryBinding lists bindings of the given data type at the access level
// selected by param: private (the caller's own bindings, optionally a single
// one), public, or apply-level (bindings the user applied for, filterable by
// status/name; content details are only exposed once approved).
func (svc *JobSetService) QueryBinding(dataType string, param sch.QueryBindingDataParam, filters sch.QueryBindingFilters) ([]uploadersdk.BindingDetail, error) {

	switch p := param.(type) {
	case *sch.PrivateLevel:
		return svc.queryPrivateBinding(p.UserID, uploadersdk.DataID(p.BindingID), dataType)
	case *sch.PublicLevel:
		var details []uploadersdk.BindingDetail
		datas, err := svc.db.UploadData().GetPublicBindings(svc.db.DefCtx(), p.Type, dataType, p.UserID)
		if err != nil {
			return nil, err
		}

		for _, data := range datas {
			// Decode the stored request payload into its concrete binding type.
			var info sch.DataBinding
			binding := uploadersdk.Binding{
				DataType: dataType,
				Content:  data.Content,
			}
			info, err = unmarshalBinding(binding)
			if err != nil {
				return nil, err
			}

			bindingDetail := uploadersdk.BindingDetail{
				ID:          data.ID,
				UserID:      data.UserID,
				UserName:    data.UserName,
				SSOId:       data.SSOId,
				Name:        data.Name,
				Info:        info,
				AccessLevel: data.AccessLevel,
				CreateTime:  data.CreateTime,
			}
			details = append(details, bindingDetail)
		}

		return details, nil
	case *sch.ApplyLevel:
		var details []uploadersdk.BindingDetail
		datas, err := svc.db.UploadData().GetApplyBindings(svc.db.DefCtx(), p.UserID, p.Type, dataType)
		if err != nil {
			return nil, err
		}

		for _, data := range datas {
			var info sch.DataBinding

			// If a status filter is set, only entries with that status are shown.
			if filters.Status != "" && data.Status != filters.Status {
				continue
			}
			if filters.Name != "" && data.Name != filters.Name {
				continue
			}

			// Only approved applications expose the binding details.
			if data.Status == sch.ApprovedStatus {
				binding := uploadersdk.Binding{
					DataType: dataType,
					Content:  data.Content,
				}
				info, err = unmarshalBinding(binding)
				if err != nil {
					return nil, err
				}
			}

			bindingDetail := uploadersdk.BindingDetail{
				ID:          data.ID,
				UserID:      data.UserID,
				UserName:    data.UserName,
				SSOId:       data.SSOId,
				Name:        data.Name,
				Info:        info, // nil unless approved
				Status:      data.Status,
				AccessLevel: data.AccessLevel,
				CreateTime:  data.CreateTime,
			}
			details = append(details, bindingDetail)
		}

		return details, nil
	}

	return nil, fmt.Errorf("unknown query binding data type")
}
|
||
|
||
// queryPrivateBinding lists a user's private bindings of the given data
// type. bindingID == -1 returns all of them (without attached packages);
// otherwise the single binding is returned together with its packages.
func (svc *JobSetService) queryPrivateBinding(userID cdssdk.UserID, bindingID uploadersdk.DataID, dataType string) ([]uploadersdk.BindingDetail, error) {
	var details []uploadersdk.BindingDetail
	if bindingID == -1 {
		datas, err := svc.db.UploadData().GetPrivateBindings(svc.db.DefCtx(), userID, dataType)
		if err != nil {
			return nil, err
		}

		for _, data := range datas {

			// Decode the stored content into the concrete binding type.
			info, err := unmarshalBinding(data)
			if err != nil {
				return nil, err
			}

			bindingDetail := uploadersdk.BindingDetail{
				ID:          data.ID,
				UserID:      data.UserID,
				Info:        info,
				AccessLevel: data.AccessLevel,
				CreateTime:  data.CreateTime,
			}
			details = append(details, bindingDetail)
		}

		return details, nil
	}

	// Single-binding lookup: also include the packages attached to it.
	data, err := svc.db.UploadData().GetBindingsByID(svc.db.DefCtx(), bindingID)
	if err != nil {
		return nil, err
	}
	info, err := unmarshalBinding(*data)
	if err != nil {
		return nil, err
	}

	packages, err := svc.db.UploadData().QueryPackageByBindingID(svc.db.DefCtx(), bindingID)
	if err != nil {
		return nil, err
	}
	detail := uploadersdk.BindingDetail{
		ID:          bindingID,
		UserID:      data.UserID,
		Packages:    packages,
		Info:        info,
		AccessLevel: data.AccessLevel,
		CreateTime:  data.CreateTime,
	}
	details = append(details, detail)

	return details, nil
}
|
||
|
||
func unmarshalBinding(data uploadersdk.Binding) (sch.DataBinding, error) {
|
||
var info sch.DataBinding
|
||
|
||
switch data.DataType {
|
||
case sch.DATASET:
|
||
var content sch.DatasetBinding
|
||
err := json.Unmarshal([]byte(data.Content), &content)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
info = &content
|
||
case sch.CODE:
|
||
var content sch.CodeBinding
|
||
err := json.Unmarshal([]byte(data.Content), &content)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
info = &content
|
||
case sch.IMAGE:
|
||
var content sch.ImageBinding
|
||
err := json.Unmarshal([]byte(data.Content), &content)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
info = &content
|
||
case sch.MODEL:
|
||
var content sch.ModelBinding
|
||
err := json.Unmarshal([]byte(data.Content), &content)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
info = &content
|
||
}
|
||
return info, nil
|
||
}
|
||
|
||
func (svc *JobSetService) CreatePackage(userID cdssdk.UserID, name string, dataType string, uploadPriority sch.UploadPriority) error {
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return fmt.Errorf("new cds client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
bucket, err := svc.db.Access().GetBucketByUserID(svc.db.DefCtx(), userID, dataType)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get bucket: %w", err)
|
||
} else if bucket == nil {
|
||
return fmt.Errorf("bucket not found")
|
||
}
|
||
|
||
// 创建package
|
||
newPackage, err := cdsCli.Package().Create(cdsapi.PackageCreate{
|
||
UserID: userID,
|
||
BucketID: bucket.ID,
|
||
Name: name,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create package: %w", err)
|
||
}
|
||
|
||
pkg := uploadersdk.Package{
|
||
UserID: userID,
|
||
PackageID: newPackage.Package.PackageID,
|
||
PackageName: name,
|
||
BucketID: bucket.ID,
|
||
DataType: dataType,
|
||
UploadPriority: uploadPriority,
|
||
}
|
||
|
||
// 对Package进行预调度,并写入到数据库中
|
||
clusters, err := svc.packageScheduler(pkg.PackageID, uploadPriority)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 写入数据库存档
|
||
err = svc.JobSetSvc().db.UploadData().InsertPackage(svc.db.DefCtx(), pkg, clusters)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (svc *JobSetService) packageScheduler(packageID cdssdk.PackageID, uploadPriority sch.UploadPriority) ([]uploadersdk.Cluster, error) {
|
||
clusterMapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query cluster mapping error: %w", err)
|
||
}
|
||
|
||
var clusters []uploadersdk.Cluster
|
||
switch uploadPriority := uploadPriority.(type) {
|
||
case *sch.Preferences:
|
||
// 进行预调度
|
||
clusterID, err := svc.preScheduler.ScheduleJob(uploadPriority.ResourcePriorities, clusterMapping)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("pre scheduling: %w", err)
|
||
}
|
||
|
||
storageID, ok := clusterMapping[*clusterID]
|
||
if !ok {
|
||
return nil, fmt.Errorf("cluster %d not found", clusterID)
|
||
}
|
||
cluster := uploadersdk.Cluster{
|
||
PackageID: packageID,
|
||
ClusterID: *clusterID,
|
||
StorageID: storageID,
|
||
}
|
||
clusters = append(clusters, cluster)
|
||
case *sch.SpecifyCluster:
|
||
// 指定集群
|
||
for _, clusterID := range uploadPriority.Clusters {
|
||
storageID, ok := clusterMapping[clusterID]
|
||
if !ok {
|
||
logger.Warnf("cluster %d not found", clusterID)
|
||
continue
|
||
}
|
||
cluster := uploadersdk.Cluster{
|
||
PackageID: packageID,
|
||
ClusterID: clusterID,
|
||
StorageID: storageID,
|
||
}
|
||
clusters = append(clusters, cluster)
|
||
}
|
||
}
|
||
|
||
if len(clusters) == 0 {
|
||
return nil, errors.New("no storage is available")
|
||
}
|
||
|
||
//for _, clst := range clusters {
|
||
// err := svc.db.UploadData().InsertUploadedCluster(svc.db.DefCtx(), clst)
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
//}
|
||
|
||
return clusters, nil
|
||
}
|
||
|
||
func (svc *JobSetService) DeletePackage(userID cdssdk.UserID, packageID cdssdk.PackageID) error {
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return fmt.Errorf("new cds client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
err = cdsCli.Package().Delete(cdsapi.PackageDelete{
|
||
UserID: userID,
|
||
PackageID: packageID,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("delete package: %w", err)
|
||
}
|
||
|
||
err = svc.JobSetSvc().db.UploadData().DeletePackage(svc.db.DefCtx(), userID, packageID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (svc *JobSetService) QueryResource(queryResource sch.ResourceRange) ([]sch.ClusterDetail, error) {
|
||
|
||
clusterDetails, err := svc.getClusterResources()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var results []sch.ClusterDetail
|
||
for _, cluster := range clusterDetails {
|
||
if cluster.Resources == nil || len(cluster.Resources) == 0 {
|
||
continue
|
||
}
|
||
|
||
ok := isAppropriateResources(cluster.Resources, queryResource)
|
||
|
||
if ok {
|
||
results = append(results, cluster)
|
||
}
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
func isAppropriateResources(resources []sch.ClusterResource, queryResource sch.ResourceRange) bool {
|
||
for _, resource := range resources {
|
||
if resource.Resource.Type == queryResource.Type {
|
||
//ok := compareResource(queryResource.GPU.Min, queryResource.GPU.Max, resource.Resource.Available.Value)
|
||
//if !ok {
|
||
// return false
|
||
//}
|
||
|
||
if resource.BaseResources == nil || len(resource.BaseResources) == 0 {
|
||
return false
|
||
}
|
||
ok := false
|
||
|
||
for _, baseResource := range resource.BaseResources {
|
||
if baseResource.Type == sch.ResourceTypeCPU {
|
||
ok = compareResource(queryResource.CPU.Min, queryResource.CPU.Max, baseResource.Available.Value)
|
||
if !ok {
|
||
return false
|
||
}
|
||
}
|
||
if baseResource.Type == sch.ResourceTypeMemory {
|
||
ok = compareResource(queryResource.Memory.Min, queryResource.Memory.Max, baseResource.Available.Value)
|
||
if !ok {
|
||
return false
|
||
}
|
||
}
|
||
if baseResource.Type == sch.ResourceTypeStorage {
|
||
ok = compareResource(queryResource.Storage.Min, queryResource.Storage.Max, baseResource.Available.Value)
|
||
if !ok {
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// compareResource reports whether v lies inside the closed interval
// [min, max]. A completely unset range (min == 0 && max == 0) matches any
// value; an inverted range (min > max) matches nothing.
func compareResource(min float64, max float64, v float64) bool {
	// An unset range acts as a wildcard.
	if min == 0 && max == 0 {
		return true
	}
	// For an inverted range no value can satisfy both bounds, so this
	// single expression also covers the min > max case.
	return v >= min && v <= max
}
|
||
|
||
// ResourceRange aggregates the available resources of all known clusters into
// one sch.ResourceRange per resource type, tracking the min/max observed
// CPU, memory and storage availability across clusters.
func (svc *JobSetService) ResourceRange() ([]sch.ResourceRange, error) {
	clusterDetails, err := svc.getClusterResources()
	if err != nil {
		return nil, err
	}

	// Map from resource type to the accumulated range data.
	resourceMap := make(map[sch.ResourceType]sch.ResourceRange)

	// Walk every ClusterDetail.
	for _, cluster := range clusterDetails {
		var CPUValue float64
		var MemValue float64
		var StorageValue float64

		// NOTE(review): this loop overwrites CPUValue/MemValue/StorageValue
		// on every resource, so only the values found in the LAST resource
		// of cluster.Resources survive, and the per-type loop below applies
		// those same three values to every resource type of this cluster.
		// Confirm this is intended and not a per-resource aggregation bug.
		for _, resource := range cluster.Resources {
			for _, baseResource := range resource.BaseResources {
				// Pick out the base-resource kinds we aggregate; skip the rest.
				switch baseResource.Type {
				case sch.ResourceTypeCPU:
					CPUValue = baseResource.Available.Value
				case sch.ResourceTypeMemory:
					MemValue = baseResource.Available.Value
				case sch.ResourceTypeStorage:
					StorageValue = baseResource.Available.Value
				}
			}

		}
		// Fold this cluster's values into the per-type range records.
		for _, resource := range cluster.Resources {

			// Key of the aggregation map.
			resourceType := resource.Resource.Type
			//availableValue := resource.Available.Value

			// Fetch the existing ResourceRange for this type.
			resourceRange, exists := resourceMap[resourceType]
			if !exists {
				// First time we see this type: start from an empty range.
				resourceRange = sch.ResourceRange{
					Type:      resourceType,
					GPU:       sch.Range{}, // initial GPU range
					GPUNumber: 0,           // initial GPU count
					CPU:       sch.Range{}, // initial CPU range
					Memory:    sch.Range{}, // initial Memory range
					Storage:   sch.Range{}, // initial Storage range
				}
			}

			// A Min of 0 is treated as "not yet set" and is always replaced.
			// NOTE(review): a genuine available value of 0 can therefore
			// never be recorded as a minimum — confirm that is acceptable.
			if CPUValue < resourceRange.CPU.Min || resourceRange.CPU.Min == 0 {
				resourceRange.CPU.Min = CPUValue
			}
			if CPUValue > resourceRange.CPU.Max {
				resourceRange.CPU.Max = CPUValue
			}

			if MemValue < resourceRange.Memory.Min || resourceRange.Memory.Min == 0 {
				resourceRange.Memory.Min = MemValue
			}
			if MemValue > resourceRange.Memory.Max {
				resourceRange.Memory.Max = MemValue
			}

			if StorageValue < resourceRange.Storage.Min || resourceRange.Storage.Min == 0 {
				resourceRange.Storage.Min = StorageValue
			}
			if StorageValue > resourceRange.Storage.Max {
				resourceRange.Storage.Max = StorageValue
			}

			// Count this resource.
			// NOTE(review): GPUNumber here counts matching resources, not
			// physical GPUs — confirm the field's intended meaning.
			resourceRange.GPUNumber++

			// Write the updated range back into the map.
			resourceMap[resourceType] = resourceRange
		}
	}

	// Convert the map into a slice for the caller.
	var result []sch.ResourceRange
	for _, rangeData := range resourceMap {
		result = append(result, rangeData)
	}

	// Return the aggregated statistics.
	return result, nil
}
|
||
|
||
func (svc *JobSetService) getClusterResources() ([]sch.ClusterDetail, error) {
|
||
schCli, err := schglb.PCMSchePool.Acquire()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new scheduler client: %w", err)
|
||
}
|
||
defer schglb.PCMSchePool.Release(schCli)
|
||
|
||
clusterMapping, err := svc.db.UploadData().GetClusterMapping(svc.db.DefCtx())
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query cluster mapping: %w", err)
|
||
}
|
||
|
||
// 查询指定算力中心
|
||
clusterIDs := make([]schsdk.ClusterID, 0, len(clusterMapping))
|
||
|
||
for id := range clusterMapping {
|
||
clusterIDs = append(clusterIDs, id)
|
||
}
|
||
|
||
clusterDetails, err := schCli.GetClusterInfo(sch.GetClusterInfoReq{
|
||
IDs: clusterIDs,
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get cluster info: %w", err)
|
||
}
|
||
|
||
if len(clusterDetails) == 0 {
|
||
return nil, errors.New("no cluster found")
|
||
}
|
||
|
||
return clusterDetails, nil
|
||
}
|
||
|
||
func (svc *JobSetService) QueryImages(IDs []int64) ([]sch.ClusterImage, error) {
|
||
images, err := svc.db.UploadData().GetImageByID(svc.db.DefCtx(), IDs)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query images: %w", err)
|
||
}
|
||
|
||
return images, nil
|
||
}
|
||
|
||
func (svc *JobSetService) ClonePackage(userID cdssdk.UserID, param uploadersdk.PackageCloneParam, cloneType string) (*cdssdk.Package, error) {
|
||
clonePackageID := param.PackageID
|
||
pkg := cdssdk.Package{
|
||
PackageID: param.PackageID,
|
||
Name: param.PackageName,
|
||
//BucketID: param.BucketID,
|
||
}
|
||
|
||
parentPkg, err := svc.db.UploadData().GetParentClonePackageByPkgID(svc.db.DefCtx(), param.PackageID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query clone package: %w", err)
|
||
}
|
||
|
||
// 如果子算法列表已经存在数据,则执行clone操作新增子算法
|
||
switch cloneType {
|
||
case sch.ChildrenType:
|
||
// 判断父算法是否已经创建
|
||
if parentPkg.ParentPackageID == 0 {
|
||
return nil, fmt.Errorf("code parent package is not exists")
|
||
}
|
||
|
||
// 查询package,用于获取bucketID
|
||
queryPkg, err := svc.db.UploadData().GetByPackageID(svc.db.DefCtx(), param.PackageID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query package: %w", err)
|
||
}
|
||
|
||
// 复制package
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new cds client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
version := strconv.FormatInt(time.Now().Unix(), 10)
|
||
packageName := fmt.Sprintf("%s_%s", param.PackageName, version)
|
||
cloneReq := cdsapi.PackageClone{
|
||
PackageID: param.PackageID,
|
||
Name: packageName,
|
||
BucketID: queryPkg.BucketID,
|
||
UserID: userID,
|
||
}
|
||
cloneResp, err := cdsCli.Package().Clone(cloneReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("clone package: %w", err)
|
||
}
|
||
clonePackageID = cloneResp.Package.PackageID
|
||
pkg = cloneResp.Package
|
||
case sch.ParentType:
|
||
// 判断父算法是否已经创建
|
||
if parentPkg.ParentPackageID != 0 {
|
||
return nil, fmt.Errorf("code parent package alread exists")
|
||
}
|
||
}
|
||
|
||
packageCloneDAO := uploadersdk.PackageCloneDAO{
|
||
ParentPackageID: param.PackageID,
|
||
ClonePackageID: clonePackageID,
|
||
Name: param.Name,
|
||
Description: param.Description,
|
||
BootstrapObjectID: param.BootstrapObjectID,
|
||
ParentImageID: param.ParentImageID,
|
||
ImageID: param.ImageID,
|
||
ClusterID: param.ClusterID,
|
||
CreateTime: time.Now(),
|
||
}
|
||
|
||
// 将package添加到version表
|
||
err = svc.db.UploadData().InsertClonePackage(svc.db.DefCtx(), packageCloneDAO)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("insert package version: %w", err)
|
||
}
|
||
|
||
// 返回package
|
||
return &pkg, nil
|
||
}
|
||
|
||
func (svc *JobSetService) QueryClonePackage(packageID cdssdk.PackageID, userID cdssdk.UserID, dataType string) ([]uploadersdk.PackageCloneVO, error) {
|
||
|
||
// 获取父算法列表
|
||
if packageID == -1 {
|
||
pkgs, err := svc.db.UploadData().GetCloneParentPackage(svc.db.DefCtx(), userID, dataType)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query clone parent package: %w", err)
|
||
}
|
||
|
||
return pkgs, nil
|
||
}
|
||
|
||
// 获取子算法
|
||
pkgs, err := svc.db.UploadData().GetChildrenClonePackageByPkgID(svc.db.DefCtx(), packageID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query children package: %w", err)
|
||
}
|
||
|
||
return pkgs, nil
|
||
}
|
||
|
||
func (svc *JobSetService) RemoveClonePackage(userID cdssdk.UserID, cloneType string, packageIDs []cdssdk.PackageID) error {
|
||
|
||
if cloneType == sch.ParentType {
|
||
err := svc.db.UploadData().RemoveClonePackage(svc.db.DefCtx(), packageIDs, []cdssdk.PackageID{})
|
||
if err != nil {
|
||
return fmt.Errorf("remove clone package: %w", err)
|
||
}
|
||
} else {
|
||
cdsCli, err := schglb.CloudreamStoragePool.Acquire()
|
||
if err != nil {
|
||
return fmt.Errorf("new cds client: %w", err)
|
||
}
|
||
defer schglb.CloudreamStoragePool.Release(cdsCli)
|
||
|
||
for _, id := range packageIDs {
|
||
packageDelete := cdsapi.PackageDelete{
|
||
UserID: userID,
|
||
PackageID: id,
|
||
}
|
||
err = cdsCli.Package().Delete(packageDelete)
|
||
if err != nil {
|
||
return fmt.Errorf("delete package: %w", err)
|
||
}
|
||
}
|
||
|
||
err = svc.db.UploadData().RemoveClonePackage(svc.db.DefCtx(), []cdssdk.PackageID{}, packageIDs)
|
||
if err != nil {
|
||
return fmt.Errorf("remove clone package: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|