diff --git a/internal/scheduler/service/utils/task/tasksync/infer.go b/internal/scheduler/service/utils/task/tasksync/infer.go
new file mode 100644
index 000000000..773088857
--- /dev/null
+++ b/internal/scheduler/service/utils/task/tasksync/infer.go
@@ -0,0 +1,288 @@
+package tasksync
+
+import (
+	"context"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/zeromicro/go-zero/core/logx"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/config"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/utils/jcs"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
+)
+
+// SyncInfer keeps the stored status of inference deploy instances in step
+// with the status reported by the clusters that run them.
+type SyncInfer struct {
+	mu                  sync.Mutex
+	aiStorages          *database.AiStorage
+	inferenceAdapterMap map[string]map[string]inference.ICluster
+	config              *config.Config
+}
+
+// NewInferTask returns a SyncInfer backed by storage. inferenceAdapterMap is
+// looked up by adapter ID and then by cluster ID, both as decimal strings.
+func NewInferTask(storage *database.AiStorage, inferenceAdapterMap map[string]map[string]inference.ICluster, config *config.Config) *SyncInfer {
+	return &SyncInfer{
+		aiStorages:          storage,
+		inferenceAdapterMap: inferenceAdapterMap,
+		config:              config,
+	}
+}
+
+// UpdateDeployInstanceStatusBatch refreshes the status of the instances in
+// insList. With needfilter set, instances already recorded as Running,
+// Stopped or Failed are skipped. insList itself is never modified.
+//
+// Each refresh runs in its own goroutine, at most three at a time. The method
+// returns once the last goroutine has been started; it does not wait for the
+// refreshes to finish.
+func (s *SyncInfer) UpdateDeployInstanceStatusBatch(insList []*models.AiInferDeployInstance, needfilter bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	pending := make([]*models.AiInferDeployInstance, 0, len(insList))
+	for _, instance := range insList {
+		if needfilter && (instance.Status == constants.Running || instance.Status == constants.Stopped || instance.Status == constants.Failed) {
+			continue
+		}
+		pending = append(pending, instance)
+	}
+
+	// buffer is a counting semaphore: the send claims a slot and the worker
+	// gives it back when it returns.
+	buffer := make(chan bool, 3)
+	for _, instance := range pending {
+		buffer <- true
+		go s.UpdateDeployInstanceStatus(instance, false, buffer)
+	}
+}
+
+// UpdateDeployTaskStatus refreshes the instances of the most recently updated
+// deploy task. It does nothing when there are no deploy tasks.
+func (s *SyncInfer) UpdateDeployTaskStatus() {
+	list, err := s.aiStorages.GetAllDeployTasks()
+	if err != nil {
+		logx.Errorf("UpdateDeployTaskStatus Get DeployTasks Error %s", err.Error())
+		return
+	}
+	if len(list) == 0 {
+		// No deploy task yet; indexing list[0] below would panic.
+		return
+	}
+
+	// Select the task with the newest UpdateTime. A value that is not valid
+	// RFC 3339 parses to the zero time, so parse errors are ignored on
+	// purpose: such a task never displaces the current candidate.
+	latest := list[0]
+	latestTime, _ := time.Parse(time.RFC3339, latest.UpdateTime)
+	for _, task := range list[1:] {
+		updated, _ := time.Parse(time.RFC3339, task.UpdateTime)
+		if updated.After(latestTime) {
+			latest, latestTime = task, updated
+		}
+	}
+
+	inslist, err := s.aiStorages.GetInstanceListByDeployTaskId(latest.Id)
+	if err != nil {
+		logx.Errorf("UpdateDeployTaskStatus Get Instances Error %s", err.Error())
+		return
+	}
+
+	// buffer is a counting semaphore that caps the refresh at two concurrent
+	// workers.
+	buffer := make(chan bool, 2)
+	for _, instance := range inslist {
+		buffer <- true
+		go s.UpdateDeployInstanceStatus(instance, false, buffer)
+	}
+}
+
+// UpdateDeployInstanceStatus asks the owning cluster for the current state of
+// instance, maps the cluster-specific status onto the coordinator's status
+// constants and stores the result.
+//
+// Nothing is stored when the adapter or cluster is not registered, when the
+// cluster query fails, or when the instance is already recorded with the
+// mapped status. Changes marked notify below are also reported through
+// ReportInferenceStatusMessages before they are stored.
+//
+// updatetime is passed through to AiStorage.UpdateInferDeployInstance. When ch
+// is non-nil it is treated as a semaphore: one value is received from it
+// before the method returns, whatever the outcome.
+func (s *SyncInfer) UpdateDeployInstanceStatus(instance *models.AiInferDeployInstance, updatetime bool, ch chan bool) {
+	if ch != nil {
+		// Hand the semaphore slot back on every return path.
+		defer func() { <-ch }()
+	}
+
+	clusters, found := s.inferenceAdapterMap[strconv.FormatInt(instance.AdapterId, 10)]
+	if !found {
+		return
+	}
+	cluster, found := clusters[strconv.FormatInt(instance.ClusterId, 10)]
+	if !found {
+		return
+	}
+
+	ins, err := cluster.GetInferDeployInstance(context.Background(), instance.InstanceId)
+	if err != nil {
+		logx.Errorf("UpdateDeployInstanceStatus Get Instance Error %s", err.Error())
+		return
+	}
+
+	// target is the status to store, mapped tells whether it is one of the
+	// coordinator's constants, and notify whether the change must be reported.
+	// A remote status with no mapping is stored verbatim.
+	target, mapped, notify := ins.Status, false, false
+	switch instance.ClusterType {
+	case storeLink.TYPE_OCTOPUS:
+		switch ins.Status {
+		case "running":
+			target, mapped, notify = constants.Running, true, true
+		case "stopped":
+			target, mapped = constants.Stopped, true
+		}
+	case storeLink.TYPE_MODELARTS:
+		switch ins.Status {
+		case "running":
+			target, mapped, notify = constants.Running, true, true
+		case "stopped":
+			target, mapped = constants.Stopped, true
+		case "failed":
+			target, mapped, notify = constants.Failed, true, true
+		}
+	case storeLink.TYPE_SHUGUANGAI:
+		switch ins.Status {
+		case "Running":
+			target, mapped = constants.Running, true
+		case "Terminated":
+			target, mapped = constants.Stopped, true
+		}
+	case storeLink.TYPE_OPENI:
+		switch ins.Status {
+		case "RUNNING":
+			target, mapped, notify = constants.Running, true, true
+		case "STOPPED":
+			target, mapped = constants.Stopped, true
+		case "CREATED_FAILED", "FAILED":
+			target, mapped, notify = constants.Failed, true, true
+		}
+	default:
+		// Unrecognised cluster type: keep the recorded status as it is.
+		target = instance.Status
+	}
+
+	if mapped && instance.Status == target {
+		// Already in sync; skip both the report and the write.
+		return
+	}
+
+	if notify {
+		// A running instance is reported with its URL; a failed one with the
+		// raw remote status as the message.
+		running := target == constants.Running
+		url, msg := ins.InferUrl, ""
+		if !running {
+			url, msg = "", ins.Status
+		}
+		taskID := strconv.FormatInt(instance.DeployInstanceTaskId, 10)
+		clusterID := strconv.FormatInt(instance.ClusterId, 10)
+		if err := s.ReportInferenceStatusMessages(instance, instance.InstanceName, taskID, clusterID, url, running, msg); err != nil {
+			logx.Errorf("############ Report Infer Task Status Message Error %s", err.Error())
+		}
+	}
+
+	instance.Status = target
+	if err := s.aiStorages.UpdateInferDeployInstance(instance, updatetime); err != nil {
+		logx.Errorf("UpdateDeployInstanceStatus Update Instance Error %s", err.Error())
+	}
+}
+
+// UpdateAutoStoppedInstance refreshes, without filtering, every instance
+// returned by AiStorage.GetInferDeployInstanceList.
+func (s *SyncInfer) UpdateAutoStoppedInstance() {
+	list, err := s.aiStorages.GetInferDeployInstanceList()
+	if err != nil {
+		logx.Errorf("UpdateAutoStoppedInstance Get Instances Error %s", err.Error())
+		return
+	}
+	if len(list) == 0 {
+		return
+	}
+
+	s.UpdateDeployInstanceStatusBatch(list, false)
+}
+
+// CheckStopStatus reports whether in.Status is the value its cluster type
+// uses for a stopped instance. Unknown cluster types yield false.
+func (s *SyncInfer) CheckStopStatus(in *inference.DeployInstance) bool {
+	switch in.ClusterType {
+	case storeLink.TYPE_OCTOPUS, storeLink.TYPE_MODELARTS:
+		return in.Status == "stopped"
+	case storeLink.TYPE_SHUGUANGAI:
+		return in.Status == "Terminated"
+	case storeLink.TYPE_OPENI:
+		return in.Status == "STOPPED"
+	default:
+		return false
+	}
+}
+
+// CheckRunningStatus reports whether in.Status is a value its cluster type
+// uses for a running instance; OpenI's "WAITING" is counted as running too.
+// Unknown cluster types yield false.
+func (s *SyncInfer) CheckRunningStatus(in *inference.DeployInstance) bool {
+	switch in.ClusterType {
+	case storeLink.TYPE_OCTOPUS, storeLink.TYPE_MODELARTS:
+		return in.Status == "running"
+	case storeLink.TYPE_SHUGUANGAI:
+		return in.Status == "Running"
+	case storeLink.TYPE_OPENI:
+		return in.Status == "RUNNING" || in.Status == "WAITING"
+	default:
+		return false
+	}
+}
+
+// ReportInferenceStatusMessages sends an "Inference" status report to the
+// configured JCS job-status URL. status is the success flag, msg the
+// accompanying message and url the inference endpoint.
+//
+// ins may be nil. The report's ID, adapter and instance fields are then left
+// empty and its cluster ID is taken from clusterId; otherwise all four are
+// taken from ins.
+func (s *SyncInfer) ReportInferenceStatusMessages(ins *models.AiInferDeployInstance, taskName string, taskId string, clusterId string, url string, status bool, msg string) error {
+	var id, adapterID, instanceID string
+	clusterID := clusterId
+	if ins != nil {
+		id = strconv.FormatInt(ins.Id, 10)
+		adapterID = strconv.FormatInt(ins.AdapterId, 10)
+		clusterID = strconv.FormatInt(ins.ClusterId, 10)
+		instanceID = ins.InstanceId
+	}
+
+	report := &jcs.JobStatusReportReq{
+		Report: &jcs.InferReportMessage{
+			Type:       "Inference",
+			TaskName:   taskName,
+			TaskID:     taskId,
+			Status:     status,
+			Message:    msg,
+			Url:        url,
+			ID:         id,
+			AdapterID:  adapterID,
+			ClusterID:  clusterID,
+			InstanceID: instanceID,
+		},
+	}
+	if err := jcs.StatusReport(s.config.JcsMiddleware.JobStatusReportUrl, report); err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/internal/scheduler/service/utils/task/tasksync/train.go b/internal/scheduler/service/utils/task/tasksync/train.go
new file mode 100644
index 000000000..79e816eae
--- /dev/null
+++ b/internal/scheduler/service/utils/task/tasksync/train.go
@@ -0,0 +1,224 @@
+package tasksync
+
+import (
+	"context"
+	"strconv"
+	"sync"
+
+	"github.com/zeromicro/go-zero/core/logx"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/config"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/utils/jcs"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
+)
+
+// SyncTrain keeps stored training tasks and their per-cluster AI sub-tasks in
+// step with the state reported by the cluster collectors.
+type SyncTrain struct {
+	mu                    sync.Mutex
+	aiStorages            *database.AiStorage
+	aiCollectorAdapterMap map[string]map[string]collector.AiCollector
+	config                *config.Config
+}
+
+// NewTrainTask returns a SyncTrain backed by storage. aiCollectorAdapterMap is
+// looked up by adapter ID and then by cluster ID, both as decimal strings.
+func NewTrainTask(storage *database.AiStorage, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, config *config.Config) *SyncTrain {
+	return &SyncTrain{
+		aiStorages:            storage,
+		aiCollectorAdapterMap: aiCollectorAdapterMap,
+		config:                config,
+	}
+}
+
+// UpdateAiTaskStatus refreshes the AI sub-tasks of every task in tasklist
+// whose AdapterTypeDict is "1" and whose status is neither Succeeded nor
+// Failed. tasklist itself is never modified.
+//
+// The sub-tasks of each task are refreshed in a background goroutine that the
+// method does not wait for.
+func (s *SyncTrain) UpdateAiTaskStatus(tasklist []*types.TaskModel) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for _, task := range tasklist {
+		if task.AdapterTypeDict != "1" || task.Status == constants.Succeeded || task.Status == constants.Failed {
+			continue
+		}
+
+		aiTaskList, err := s.aiStorages.GetAiTaskListById(task.Id)
+		if err != nil {
+			logx.Errorf("UpdateAiTaskStatus Get AiTask Error %s", err.Error())
+			continue
+		}
+		if len(aiTaskList) == 0 {
+			continue
+		}
+
+		go s.updateAiTask(aiTaskList)
+	}
+}
+
+// UpdateTaskStatus copies the state of the single AI sub-task of each task in
+// tasklist onto the task itself and stores it. Tasks are skipped when their
+// AdapterTypeDict is not "1", when they are already Succeeded, Failed or
+// Cancelled, or when they do not have exactly one sub-task.
+//
+// A sub-task that is Completed or Failed marks the task Succeeded or Failed
+// respectively and is reported through reportStatusMessages.
+func (s *SyncTrain) UpdateTaskStatus(tasklist []*types.TaskModel) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for _, task := range tasklist {
+		if task.AdapterTypeDict != "1" || task.Status == constants.Succeeded || task.Status == constants.Failed || task.Status == constants.Cancelled {
+			continue
+		}
+
+		aiTask, err := s.aiStorages.GetAiTaskListById(task.Id)
+		if err != nil {
+			logx.Errorf("UpdateTaskStatus Get AiTask Error %s", err.Error())
+			continue
+		}
+		if len(aiTask) == 0 {
+			continue
+		}
+
+		logx.Infof("############ Report Status Message Before switch %s", task.Status)
+		if len(aiTask) != 1 {
+			// Only tasks with exactly one sub-task are synchronised here.
+			continue
+		}
+
+		sub := aiTask[0]
+		logx.Infof("############ Report Status Message Switch %s", sub.Status)
+		switch sub.Status {
+		case constants.Completed:
+			task.Status = constants.Succeeded
+			logx.Infof("############ Report Status Message Before Sending %s", task.Status)
+			if err := s.reportStatusMessages(task, sub, true); err != nil {
+				logx.Errorf("reportStatusMessages Error %s", err.Error())
+			}
+		case constants.Failed:
+			task.Status = constants.Failed
+			logx.Infof("############ Report Status Message Before Sending %s", task.Status)
+			if err := s.reportStatusMessages(task, sub, false); err != nil {
+				logx.Errorf("reportStatusMessages Error %s", err.Error())
+			}
+		default:
+			task.Status = sub.Status
+		}
+
+		task.StartTime = sub.StartTime
+		task.EndTime = sub.EndTime
+		if err := s.aiStorages.UpdateTask(task); err != nil {
+			logx.Errorf("UpdateTaskStatus Update Task Error %s", err.Error())
+		}
+	}
+}
+
+// updateAiTask refreshes every AI sub-task in aiTaskList that has a job ID
+// and is not yet Completed, Failed or Cancelled. The sub-tasks are refreshed
+// concurrently and the method returns when all of them are done.
+func (s *SyncTrain) updateAiTask(aiTaskList []*models.TaskAi) {
+	var wg sync.WaitGroup
+	for _, aitask := range aiTaskList {
+		t := aitask // per-iteration copy captured by the goroutine below
+		if t.Status == constants.Completed || t.Status == constants.Failed || t.Status == constants.Cancelled || t.JobId == "" {
+			continue
+		}
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			s.syncAiTask(t)
+		}()
+	}
+	wg.Wait()
+}
+
+// syncAiTask fetches the training job behind t from its cluster collector,
+// records a notice when the status changed, and stores the refreshed status
+// and start/end times. Lookup and query failures are logged and end the
+// refresh early.
+func (s *SyncTrain) syncAiTask(t *models.TaskAi) {
+	const logFormat = "###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n"
+
+	adapterID := strconv.FormatInt(t.AdapterId, 10)
+	clusterID := strconv.FormatInt(t.ClusterId, 10)
+
+	// Indexing a missing adapter yields a nil inner map, so one comma-ok
+	// lookup covers both levels. Calling through a missing entry would panic.
+	col, found := s.aiCollectorAdapterMap[adapterID][clusterID]
+	if !found {
+		logx.Errorf(logFormat, t.Id, t.ClusterId, t.JobId, "no collector registered for this adapter and cluster")
+		return
+	}
+
+	trainingTask, err := col.GetTrainingTask(context.Background(), t.JobId)
+	if err != nil {
+		logx.Errorf(logFormat, t.Id, t.ClusterId, t.JobId, err.Error())
+		return
+	}
+	if trainingTask == nil {
+		return
+	}
+
+	if t.Status != trainingTask.Status {
+		// Anything other than the three statuses below is announced as pending.
+		label, text := "pending", "任务pending"
+		switch trainingTask.Status {
+		case constants.Running:
+			label, text = "running", "任务运行中"
+		case constants.Failed:
+			label, text = "failed", "任务失败"
+		case constants.Completed:
+			label, text = "completed", "任务完成"
+		}
+		s.aiStorages.AddNoticeInfo(adapterID, t.AdapterName, clusterID, t.ClusterName, t.Name, label, text)
+		t.Status = trainingTask.Status
+	}
+
+	t.StartTime = trainingTask.Start
+	t.EndTime = trainingTask.End
+	if err := s.aiStorages.UpdateAiTask(t); err != nil {
+		logx.Errorf(logFormat, t.Id, t.ClusterId, t.JobId, err.Error())
+	}
+}
+
+// reportStatusMessages sends a "Train" status report for task to the
+// configured JCS job-status URL and then passes the report to
+// jcs.TempSaveReportToTask. status is the success flag of the report.
+//
+// The report's output is the job ID for cluster "openI", aiTask.Output for
+// cluster "鹏城云脑II-modelarts", and empty for any other cluster name.
+func (s *SyncTrain) reportStatusMessages(task *types.TaskModel, aiTask *models.TaskAi, status bool) error {
+	var output string
+	switch aiTask.ClusterName {
+	case "openI":
+		output = aiTask.JobId
+	case "鹏城云脑II-modelarts":
+		output = aiTask.Output
+	}
+
+	report := &jcs.JobStatusReportReq{
+		Report: &jcs.TrainReportMessage{
+			Type:      "Train",
+			TaskName:  task.Name,
+			TaskID:    strconv.FormatInt(task.Id, 10),
+			Status:    status,
+			Message:   "",
+			ClusterID: strconv.FormatInt(aiTask.ClusterId, 10),
+			Output:    output,
+		},
+	}
+
+	if err := jcs.StatusReport(s.config.JcsMiddleware.JobStatusReportUrl, report); err != nil {
+		return err
+	}
+	if err := jcs.TempSaveReportToTask(s.aiStorages, task, report); err != nil {
+		return err
+	}
+	return nil
+}