224 lines
6.0 KiB
Go
224 lines
6.0 KiB
Go
package schedule
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
|
"gopkg.in/yaml.v2"
|
|
"strings"
|
|
)
|
|
|
|
type ScheduleRunTaskLogic struct {
|
|
logx.Logger
|
|
ctx context.Context
|
|
svcCtx *svc.ServiceContext
|
|
}
|
|
|
|
func NewScheduleRunTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ScheduleRunTaskLogic {
|
|
return &ScheduleRunTaskLogic{
|
|
Logger: logx.WithContext(ctx),
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
}
|
|
}
|
|
|
|
func (l *ScheduleRunTaskLogic) ScheduleRunTask(req *types.RunTaskReq) (resp *types.RunTaskResp, err error) {
|
|
// find task
|
|
task, err := l.svcCtx.Scheduler.AiStorages.GetTaskById(req.TaskID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if task == nil {
|
|
return nil, errors.New("task not found ")
|
|
}
|
|
|
|
if task.Status == constants.Cancelled {
|
|
return nil, errors.New("task has been cancelled ")
|
|
}
|
|
|
|
var clusters []*strategy.AssignedCluster
|
|
err = yaml.Unmarshal([]byte(task.YamlString), &clusters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
opt := &option.AiOption{
|
|
AdapterId: ADAPTERID,
|
|
TaskName: task.Name,
|
|
}
|
|
// update assignedClusters
|
|
err = updateClustersByScheduledDatas(task.Id, &clusters, req.ScheduledDatas)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
aiSchdl, err := schedulers.NewAiScheduler(l.ctx, "", l.svcCtx.Scheduler, opt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
results, err := l.svcCtx.Scheduler.AssignAndSchedule(aiSchdl, scheduler.SUBMIT_MODE_STORAGE_SCHEDULE, clusters)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rs := (results).([]*schedulers.AiResult)
|
|
|
|
err = l.SaveResult(task, rs, opt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (l *ScheduleRunTaskLogic) SaveResult(task *models.Task, results []*schedulers.AiResult, opt *option.AiOption) error {
|
|
|
|
for _, r := range results {
|
|
|
|
opt.ComputeCard = strings.ToUpper(r.Card)
|
|
|
|
adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(r.AdapterId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(r.ClusterId)
|
|
|
|
err = l.svcCtx.Scheduler.AiStorages.SaveAiTask(task.Id, opt, adapterName, r.ClusterId, clusterName, r.JobId, constants.Saved, r.Msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(r.AdapterId, adapterName, r.ClusterId, clusterName, r.TaskName, "create", "任务创建中")
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
func updateClustersByScheduledDatas(taskId int64, assignedClusters *[]*strategy.AssignedCluster, scheduledDatas []*types.DataScheduleResults) error {
|
|
for _, cluster := range *assignedClusters {
|
|
for _, data := range scheduledDatas {
|
|
switch data.DataType {
|
|
case "dataset":
|
|
for _, result := range data.Results {
|
|
if !result.Status {
|
|
continue
|
|
}
|
|
for _, c := range result.Clusters {
|
|
if cluster.ClusterId == c.ClusterID {
|
|
if c.JsonData == "" {
|
|
continue
|
|
}
|
|
jsonData := struct {
|
|
Name string `json:"name"`
|
|
Id string `json:"id"`
|
|
}{}
|
|
err := json.Unmarshal([]byte(c.JsonData), &jsonData)
|
|
if err != nil {
|
|
return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "dataset")
|
|
}
|
|
cluster.DatasetId = jsonData.Id
|
|
}
|
|
}
|
|
}
|
|
case "image":
|
|
for _, result := range data.Results {
|
|
if !result.Status {
|
|
continue
|
|
}
|
|
for _, c := range result.Clusters {
|
|
if cluster.ClusterId == c.ClusterID {
|
|
if c.JsonData == "" {
|
|
continue
|
|
}
|
|
jsonData := struct {
|
|
Name string `json:"name"`
|
|
Id string `json:"id"`
|
|
}{}
|
|
err := json.Unmarshal([]byte(c.JsonData), &jsonData)
|
|
if err != nil {
|
|
return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "image")
|
|
}
|
|
cluster.ImageId = jsonData.Id
|
|
}
|
|
}
|
|
}
|
|
case "code":
|
|
for _, result := range data.Results {
|
|
if !result.Status {
|
|
continue
|
|
}
|
|
for _, c := range result.Clusters {
|
|
if cluster.ClusterId == c.ClusterID {
|
|
if c.JsonData == "" {
|
|
continue
|
|
}
|
|
jsonData := struct {
|
|
Name string `json:"name"`
|
|
Id string `json:"id"`
|
|
}{}
|
|
err := json.Unmarshal([]byte(c.JsonData), &jsonData)
|
|
if err != nil {
|
|
return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "code")
|
|
}
|
|
cluster.CodeId = jsonData.Id
|
|
}
|
|
}
|
|
}
|
|
case "model":
|
|
for _, result := range data.Results {
|
|
if !result.Status {
|
|
continue
|
|
}
|
|
for _, c := range result.Clusters {
|
|
if cluster.ClusterId == c.ClusterID {
|
|
if c.JsonData == "" {
|
|
continue
|
|
}
|
|
jsonData := struct {
|
|
Name string `json:"name"`
|
|
Id string `json:"id"`
|
|
}{}
|
|
err := json.Unmarshal([]byte(c.JsonData), &jsonData)
|
|
if err != nil {
|
|
return fmt.Errorf("jsonData convert failed, task %d, cluster %s, datatype %s", taskId, cluster.ClusterId, "model")
|
|
}
|
|
cluster.ModelId = jsonData.Id
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, cluster := range *assignedClusters {
|
|
if cluster.DatasetId == "" {
|
|
return fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "DatasetId")
|
|
}
|
|
|
|
if cluster.ImageId == "" {
|
|
return fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "ImageId")
|
|
}
|
|
|
|
if cluster.CodeId == "" {
|
|
return fmt.Errorf("failed to run task %d, cluster %s cannot find %s", taskId, cluster.ClusterId, "CodeId")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|