pcm-coordinator/internal/logic/schedule/schedulecreatetasklogic.go

package schedule

import (
	"context"
	"fmt"
	"slices"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/common"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/utils/task"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
	"gopkg.in/yaml.v3"
)
const (
	// TRAINNING_TASK_REPLICA is the replica count passed to the scheduling strategies.
	TRAINNING_TASK_REPLICA    = 1
	TRAINNING_TASK_SUFFIX_LEN = 10
	// QUERY_RESOURCE_RETRY is the number of attempts made when querying cluster resources.
	QUERY_RESOURCE_RETRY = 3
)
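// ClustersWithDataDistributes pairs the clusters assigned by the scheduler with the
// data distribution entries that are relevant to those clusters.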
type ClustersWithDataDistributes struct {
	Clusters        []*strategy.AssignedCluster
	DataDistributes *types.DataDistribute
}
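// ScheduleCreateTaskLogic handles the creation and scheduling of training tasks.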
type ScheduleCreateTaskLogic struct {
	logx.Logger
	ctx           context.Context
	svcCtx        *svc.ServiceContext
	queryResource *QueryResourcesLogic
}
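// NewScheduleCreateTaskLogic builds a ScheduleCreateTaskLogic bound to the request context.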
func NewScheduleCreateTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ScheduleCreateTaskLogic {
	return &ScheduleCreateTaskLogic{
		Logger:        logx.WithContext(ctx),
		ctx:           ctx,
		svcCtx:        svcCtx,
		queryResource: NewQueryResourcesLogic(ctx, svcCtx),
	}
}
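// generateFilteredDataDistributes copies the data distribution (dataset, image, code and
// model entries) and, for each entry, keeps only the clusters that belong to the assigned
// cluster set.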
func generateFilteredDataDistributes(clusters []*strategy.AssignedCluster, distribute types.DataDistribute) *ClustersWithDataDistributes {
	var clusterIds []string
	for _, c := range clusters {
		clusterIds = append(clusterIds, c.ClusterId)
	}
	clustersWithDataDistributes := &ClustersWithDataDistributes{
		Clusters: clusters,
		DataDistributes: &types.DataDistribute{
			Dataset: make([]*types.DatasetDistribute, 0),
			Image:   make([]*types.ImageDistribute, 0),
			Model:   make([]*types.ModelDistribute, 0),
			Code:    make([]*types.CodeDistribute, 0),
		},
	}
	for _, datasetDistribute := range distribute.Dataset {
		dataset := &types.DatasetDistribute{}
		dataset.DataName = datasetDistribute.DataName
		dataset.PackageID = datasetDistribute.PackageID
		clusterScheduledList := make([]*types.ClusterScheduled, 0)
		for _, cluster := range datasetDistribute.Clusters {
			if slices.Contains(clusterIds, cluster.ClusterID) {
				clusterScheduledList = append(clusterScheduledList, cluster)
			}
		}
		dataset.Clusters = clusterScheduledList
		clustersWithDataDistributes.DataDistributes.Dataset = append(clustersWithDataDistributes.DataDistributes.Dataset, dataset)
	}
	for _, imageDistribute := range distribute.Image {
		image := &types.ImageDistribute{}
		image.DataName = imageDistribute.DataName
		image.PackageID = imageDistribute.PackageID
		clusterScheduledList := make([]*types.ClusterScheduled, 0)
		for _, cluster := range imageDistribute.Clusters {
			if slices.Contains(clusterIds, cluster.ClusterID) {
				clusterScheduledList = append(clusterScheduledList, cluster)
			}
		}
		image.Clusters = clusterScheduledList
		clustersWithDataDistributes.DataDistributes.Image = append(clustersWithDataDistributes.DataDistributes.Image, image)
	}
	for _, codeDistribute := range distribute.Code {
		code := &types.CodeDistribute{}
		code.DataName = codeDistribute.DataName
		code.PackageID = codeDistribute.PackageID
		code.Output = codeDistribute.Output
		clusterScheduledList := make([]*types.ClusterScheduled, 0)
		for _, cluster := range codeDistribute.Clusters {
			if slices.Contains(clusterIds, cluster.ClusterID) {
				clusterScheduledList = append(clusterScheduledList, cluster)
			}
		}
		code.Clusters = clusterScheduledList
		clustersWithDataDistributes.DataDistributes.Code = append(clustersWithDataDistributes.DataDistributes.Code, code)
	}
	for _, modelDistribute := range distribute.Model {
		model := &types.ModelDistribute{}
		model.DataName = modelDistribute.DataName
		model.PackageID = modelDistribute.PackageID
		clusterScheduledList := make([]*types.ClusterScheduled, 0)
		for _, cluster := range modelDistribute.Clusters {
			if slices.Contains(clusterIds, cluster.ClusterID) {
				clusterScheduledList = append(clusterScheduledList, cluster)
			}
		}
		model.Clusters = clusterScheduledList
		clustersWithDataDistributes.DataDistributes.Model = append(clustersWithDataDistributes.DataDistributes.Model, model)
	}
	return clustersWithDataDistributes
}
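// ScheduleCreateTask validates the request, resolves the target clusters (directly for a
// single-cluster request, via the configured scheduling strategy otherwise), computes which
// packages still need to be distributed to which clusters, and creates the task.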
func (l *ScheduleCreateTaskLogic) ScheduleCreateTask(req *types.CreateTaskReq) (resp *types.CreateTaskResp, err error) {
	resp = &types.CreateTaskResp{}
	err = task.ValidateJobResources(req.JobResources, "training")
	if err != nil {
		return nil, err
	}
	taskName, err := l.svcCtx.Scheduler.AiService.HandleDuplicateTaskName(req.Name, "training")
	if err != nil {
		return nil, err
	}
	var clusters []string
	if len(req.JobResources.Clusters) == 1 {
		clusters = append(clusters, req.JobResources.Clusters[0].ClusterID)
		schedatas, err := l.generateScheduleResult(req.DataDistributes, clusters)
		if err != nil {
			return nil, err
		}
		assignedClusters := task.CopyParams([]*strategy.AssignedCluster{{
			ClusterId: req.JobResources.Clusters[0].ClusterID, Replicas: 1,
		}}, req.JobResources.Clusters, "")
		// filter data distribution
		clustersWithDataDistributes := generateFilteredDataDistributes(assignedClusters, req.DataDistributes)
		taskId, err := l.createTask(taskName, req.Description, req.UserId, req.JobResources.ScheduleStrategy, clustersWithDataDistributes, req.Token, req.UserIp)
		if err != nil {
			return nil, err
		}
		resp.ScheduleDatas = schedatas
		resp.TaskID = taskId
		resp.TaskName = taskName
		return resp, nil
	}

	assignedClusters, err := l.getAssignedClustersByStrategy(&req.JobResources, &req.DataDistributes)
	if err != nil {
		return nil, err
	}
	if len(assignedClusters) == 0 {
		return nil, fmt.Errorf("failed to create task, no scheduled cluster found")
	}
	for _, c := range assignedClusters {
		clusters = append(clusters, c.ClusterId)
	}
	schedatas, err := l.generateScheduleResult(req.DataDistributes, clusters)
	if err != nil {
		return nil, err
	}
	// filter data distribution
	clustersWithDataDistributes := generateFilteredDataDistributes(assignedClusters, req.DataDistributes)
	taskId, err := l.createTask(taskName, req.Description, req.UserId, req.JobResources.ScheduleStrategy, clustersWithDataDistributes, req.Token, req.UserIp)
	if err != nil {
		return nil, err
	}
	resp.ScheduleDatas = schedatas
	resp.TaskID = taskId
	resp.TaskName = taskName
	return resp, nil
}
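// getAssignedClustersByStrategy picks the target clusters according to the requested
// scheduling strategy: least-load-first queries live cluster resources (with retries),
// while data-locality works from the existing data distribution.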
func (l *ScheduleCreateTaskLogic) getAssignedClustersByStrategy(resources *types.JobResources, dataDistribute *types.DataDistribute) ([]*strategy.AssignedCluster, error) {
	var assignedClusters []*strategy.AssignedCluster
	switch resources.ScheduleStrategy {
	case strategy.LEASTLOADFIRST:
		var resSpecs []*collector.ResourceSpec
		var resCount int
		// retry the resource query, pausing briefly between attempts
		for i := 0; i < QUERY_RESOURCE_RETRY; i++ {
			qResources, err := l.queryResource.QueryResourcesByClusterId(nil, "Train")
			if err != nil {
				time.Sleep(time.Second)
				continue
			}
			for _, resource := range qResources {
				if resource.Resources != nil {
					resCount++
				}
			}
			if resCount >= 1 {
				resSpecs = qResources
				break
			}
			resCount = 0
			time.Sleep(time.Second)
		}
		if resCount == 0 {
			return nil, fmt.Errorf("failed to create task, no cluster resources available")
		}
		strtg := strategy.NewLeastLoadFirst(TRAINNING_TASK_REPLICA, resSpecs)
		clusters, err := strtg.Schedule()
		if err != nil {
			return nil, err
		}
		assignedClusters = task.CopyParams(clusters, resources.Clusters, "")
	case strategy.DATA_LOCALITY:
		strtg := strategy.NewDataLocality(TRAINNING_TASK_REPLICA, dataDistribute)
		clusters, err := strtg.Schedule()
		if err != nil {
			return nil, err
		}
		assignedClusters = task.CopyParams(clusters, resources.Clusters, "")
	default:
		return nil, errors.New("no scheduling strategy has been chosen")
	}
	return assignedClusters, nil
}
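// createTask serializes the assigned clusters and filtered data distribution to YAML and
// persists the task through the scheduler; synergyStatus is set when more than one cluster
// is involved.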
func (l *ScheduleCreateTaskLogic) createTask(taskName string, desc string, userId int64, strategyName string, clustersWithDataDistributes *ClustersWithDataDistributes, token string, userIp string) (int64, error) {
	var synergyStatus int64
	if len(clustersWithDataDistributes.Clusters) > 1 {
		synergyStatus = 1
	}
	y, err := yaml.Marshal(clustersWithDataDistributes)
	if err != nil {
		return 0, errors.Wrap(err, "failed to marshal clusters with data distributions")
	}
	taskId, err := l.svcCtx.Scheduler.CreateTask(taskName, desc, userId, synergyStatus, strategyName, string(y), token, userIp, &l.svcCtx.Config)
	if err != nil {
		return 0, err
	}
	return taskId, nil
}
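// generateScheduleResult lists, for each dataset, code, image and model package, the target
// clusters that do not yet hold it, then resolves the storage type of those clusters.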
func (l *ScheduleCreateTaskLogic) generateScheduleResult(distribute types.DataDistribute, clusters []string) ([]*types.ScheduleData, error) {
	var schedatas []*types.ScheduleData
	for _, d := range distribute.Dataset {
		data := &types.ScheduleData{
			DataType:   "dataset",
			PackageID:  d.PackageID,
			ClusterIDs: make([]string, 0),
		}
		var cSlc []string
		for _, cluster := range d.Clusters {
			cSlc = append(cSlc, cluster.ClusterID)
		}
		for _, cluster := range clusters {
			if !slices.Contains(cSlc, cluster) {
				data.ClusterIDs = append(data.ClusterIDs, cluster)
			}
		}
		if len(data.ClusterIDs) != 0 {
			schedatas = append(schedatas, data)
		}
	}
	for _, d := range distribute.Code {
		data := &types.ScheduleData{
			DataType:   "code",
			PackageID:  d.PackageID,
			ClusterIDs: make([]string, 0),
		}
		var cSlc []string
		for _, cluster := range d.Clusters {
			cSlc = append(cSlc, cluster.ClusterID)
		}
		for _, cluster := range clusters {
			if !slices.Contains(cSlc, cluster) {
				data.ClusterIDs = append(data.ClusterIDs, cluster)
			}
		}
		if len(data.ClusterIDs) != 0 {
			schedatas = append(schedatas, data)
		}
	}
	for _, d := range distribute.Image {
		data := &types.ScheduleData{
			DataType:   "image",
			PackageID:  d.PackageID,
			ClusterIDs: make([]string, 0),
		}
		var cSlc []string
		for _, cluster := range d.Clusters {
			cSlc = append(cSlc, cluster.ClusterID)
		}
		for _, cluster := range clusters {
			if !slices.Contains(cSlc, cluster) {
				data.ClusterIDs = append(data.ClusterIDs, cluster)
			}
		}
		if len(data.ClusterIDs) != 0 {
			schedatas = append(schedatas, data)
		}
	}
	for _, d := range distribute.Model {
		data := &types.ScheduleData{
			DataType:   "model",
			PackageID:  d.PackageID,
			ClusterIDs: make([]string, 0),
		}
		var cSlc []string
		for _, cluster := range d.Clusters {
			cSlc = append(cSlc, cluster.ClusterID)
		}
		for _, cluster := range clusters {
			if !slices.Contains(cSlc, cluster) {
				data.ClusterIDs = append(data.ClusterIDs, cluster)
			}
		}
		if len(data.ClusterIDs) != 0 {
			schedatas = append(schedatas, data)
		}
	}
	if len(schedatas) != 0 {
		err := l.updateStorageType(&schedatas)
		if err != nil {
			return nil, err
		}
	}
	return schedatas, nil
}
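// updateStorageType resolves each scheduled cluster to its storage type and records the
// de-duplicated, comma-separated list on the schedule data entry.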
func (l *ScheduleCreateTaskLogic) updateStorageType(schedatas *[]*types.ScheduleData) error {
	for _, s := range *schedatas {
		var sTypes []string
		for _, id := range s.ClusterIDs {
			cluster, err := l.svcCtx.Scheduler.AiStorages.GetClustersById(id)
			if err != nil {
				return err
			}
			stype, ok := storeLink.StorageTypeMap[strings.Title(cluster.Name)]
			if ok {
				sTypes = append(sTypes, stype)
			}
		}
		sTypes = common.Unique(sTypes)
		s.StorageType = strings.Join(sTypes, storeLink.COMMA)
	}
	return nil
}