pcm-coordinator/internal/scheduler/service/utils/status/statusSync.go

243 lines
7.3 KiB
Go

package status
import (
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/config"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/utils/jcs"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"strconv"
"sync"
)
// TaskStatus synchronizes the status of scheduler tasks and their AI
// sub-tasks between the clusters and local storage.
type TaskStatus struct {
	// taskSyncLock is held for the duration of UpdateAiTaskStatus and
	// UpdateTaskStatus, so the two never run at the same time.
	taskSyncLock sync.Mutex

	// aiStorages is the persistence layer used to load and update tasks,
	// AI sub-tasks and notice records.
	aiStorages *database.AiStorage

	// aiCollectorAdapterMap holds the collectors used to query training
	// jobs. It is indexed first by adapter id and then by cluster id, both
	// formatted as base-10 strings.
	aiCollectorAdapterMap map[string]map[string]collector.AiCollector

	// config supplies JcsMiddleware.JobStatusReportUrl, the target of
	// status reports.
	config *config.Config
}
// NewTaskStatus returns a TaskStatus that uses the given storage, collector
// map and configuration. The arguments are stored as passed, without copying.
func NewTaskStatus(storage *database.AiStorage, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, config *config.Config) *TaskStatus {
	ts := new(TaskStatus)
	ts.aiStorages = storage
	ts.aiCollectorAdapterMap = aiCollectorAdapterMap
	ts.config = config
	return ts
}
// UpdateAiTaskStatus triggers a refresh of the AI sub-tasks that belong to
// the given tasks.
//
// Only tasks whose AdapterTypeDict is "1" and whose Status is neither
// constants.Succeeded nor constants.Failed are considered; nil entries are
// ignored. For every remaining task the AI sub-tasks are loaded from storage
// and handed to updateAiTask in a new goroutine. This method does not wait
// for those goroutines, so the refresh may still be running after it returns.
//
// tasklist is only read. Calls are serialized with UpdateTaskStatus through
// taskSyncLock.
func (s *TaskStatus) UpdateAiTaskStatus(tasklist []*types.TaskModel) {
	s.taskSyncLock.Lock()
	defer s.taskSyncLock.Unlock()

	for _, task := range tasklist {
		if task == nil {
			continue
		}
		if task.AdapterTypeDict != "1" || task.Status == constants.Succeeded || task.Status == constants.Failed {
			continue
		}

		aiTaskList, err := s.aiStorages.GetAiTaskListById(task.Id)
		if err != nil {
			logx.Errorf("UpdateAiTaskStatus Get AiTask Error %s", err.Error())
			// The result of a failed lookup is not trustworthy; skip this
			// task and let the next sync round retry it.
			continue
		}
		if len(aiTaskList) == 0 {
			continue
		}

		go s.updateAiTask(aiTaskList)
	}
}
// UpdateTaskStatus derives the status of each given task from its AI sub-task
// and persists the result.
//
// Only tasks whose AdapterTypeDict is "1" and whose Status is not
// constants.Succeeded, constants.Failed or constants.Cancelled are
// considered; nil entries are ignored. A task is updated only when storage
// holds exactly one AI sub-task for it:
//
//   - sub-task Completed -> task Succeeded, and a status report is sent
//   - sub-task Failed    -> task Failed, and a status report is sent
//   - anything else      -> task takes over the sub-task status
//
// StartTime and EndTime are copied from the sub-task and the task is written
// back with UpdateTask. The TaskModel values referenced by tasklist are
// modified in place. Errors are logged and do not stop the remaining tasks
// from being processed.
//
// Calls are serialized with UpdateAiTaskStatus through taskSyncLock.
func (s *TaskStatus) UpdateTaskStatus(tasklist []*types.TaskModel) {
	s.taskSyncLock.Lock()
	defer s.taskSyncLock.Unlock()

	for _, task := range tasklist {
		if task == nil {
			continue
		}
		if task.AdapterTypeDict != "1" || task.Status == constants.Succeeded || task.Status == constants.Failed || task.Status == constants.Cancelled {
			continue
		}

		aiTasks, err := s.aiStorages.GetAiTaskListById(task.Id)
		if err != nil {
			logx.Errorf("UpdateTaskStatus Get AiTask Error %s", err.Error())
			// The result of a failed lookup is not trustworthy; skip this
			// task and let the next sync round retry it.
			continue
		}
		if len(aiTasks) == 0 {
			continue
		}

		// NOTE(review): the "############" lines are progress traces written
		// at error level; the level is left as it was.
		logx.Errorf("############ Report Status Message Before switch %s", task.Status)
		if len(aiTasks) != 1 {
			continue
		}
		aiTask := aiTasks[0]
		if aiTask == nil {
			continue
		}
		logx.Errorf("############ Report Status Message Switch %s", aiTask.Status)

		// Completed and Failed are the two sub-task states that get reported.
		report := true
		switch aiTask.Status {
		case constants.Completed:
			task.Status = constants.Succeeded
		case constants.Failed:
			task.Status = constants.Failed
		default:
			task.Status = aiTask.Status
			report = false
		}
		if report {
			logx.Errorf("############ Report Status Message Before Sending %s", task.Status)
			if err := s.reportStatusMessages(task, aiTask); err != nil {
				// A failed report must not prevent the status from being saved.
				logx.Errorf("reportStatusMessages Error %s", err.Error())
			}
		}

		task.StartTime = aiTask.StartTime
		task.EndTime = aiTask.EndTime
		if err := s.aiStorages.UpdateTask(task); err != nil {
			logx.Errorf("UpdateTaskStatus Update Task Error %s", err.Error())
		}
	}
}
// updateAiTask asks the responsible collector for the current state of every
// unfinished AI sub-task in aiTaskList and writes the result back to storage.
//
// Entries that are nil, have an empty JobId, or are already Completed, Failed
// or Cancelled are skipped. The remaining ones are queried concurrently, one
// goroutine per sub-task, and the call returns once all of them are done.
//
// When the reported status differs from the stored one, a notice record is
// added and the sub-task takes over the new status. StartTime and EndTime are
// always refreshed from the collector's answer. Failures are logged and not
// returned; the affected sub-task is left unchanged in storage.
func (s *TaskStatus) updateAiTask(aiTaskList []*models.TaskAi) {
	var wg sync.WaitGroup
	for _, aitask := range aiTaskList {
		t := aitask // per-iteration copy captured by the goroutine below
		if t == nil || t.JobId == "" {
			continue
		}
		if t.Status == constants.Completed || t.Status == constants.Failed || t.Status == constants.Cancelled {
			continue
		}

		wg.Add(1)
		go func() {
			// A single deferred Done covers every return path.
			defer wg.Done()

			adapterID := strconv.FormatInt(t.AdapterId, 10)
			clusterID := strconv.FormatInt(t.ClusterId, 10)

			logFailure := func(cause error) {
				msg := fmt.Sprintf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, cause.Error())
				// Pass the text as an argument rather than as the format
				// string, so a '%' inside the cause is printed verbatim.
				logx.Errorf("%s", msg)
			}

			// Without this check a missing adapter/cluster entry would be
			// used as if it were a registered collector.
			aiCollector, ok := s.aiCollectorAdapterMap[adapterID][clusterID]
			if !ok {
				logFailure(errors.New("no collector registered for adapter " + adapterID + " and cluster " + clusterID))
				return
			}

			// The context of a zero-value http.Request is the background
			// context, so no deadline or cancellation is imposed from here.
			req := http.Request{}
			trainingTask, err := aiCollector.GetTrainingTask(req.Context(), t.JobId)
			if err != nil {
				// A timeout is recognised separately but is currently handled
				// exactly like any other failure: log and give up for this round.
				if status.Code(err) == codes.DeadlineExceeded {
					logFailure(err)
					return
				}
				logFailure(err)
				return
			}
			if trainingTask == nil {
				return
			}

			if t.Status != trainingTask.Status {
				var noticeType, incident string
				switch trainingTask.Status {
				case constants.Running:
					noticeType, incident = "running", "任务运行中"
				case constants.Failed:
					noticeType, incident = "failed", "任务失败"
				case constants.Completed:
					noticeType, incident = "completed", "任务完成"
				default:
					// Every other status is announced as pending.
					noticeType, incident = "pending", "任务pending"
				}
				s.aiStorages.AddNoticeInfo(adapterID, t.AdapterName, clusterID, t.ClusterName, t.Name, noticeType, incident)
				t.Status = trainingTask.Status
			}

			t.StartTime = trainingTask.Start
			t.EndTime = trainingTask.End
			if err := s.aiStorages.UpdateAiTask(t); err != nil {
				logFailure(err)
			}
		}()
	}
	wg.Wait()
}
// reportStatusMessages sends a "Train" status report for task to the URL in
// config.JcsMiddleware.JobStatusReportUrl and, when that succeeds, hands the
// same report to jcs.TempSaveReportToTask.
//
// The report's Output depends on aiTask.ClusterName: aiTask.JobId for
// "openI", aiTask.Output for "鹏城云脑II-modelarts", and the empty string for
// any other cluster. The first error from either step is returned as is.
//
// NOTE(review): Status is always true and Message always empty, including
// when the caller reports a failed task — confirm this matches what the
// receiver expects.
func (s *TaskStatus) reportStatusMessages(task *types.TaskModel, aiTask *models.TaskAi) error {
	msg := &jcs.TrainReportMessage{
		Type:     "Train",
		TaskName: task.Name,
		TaskID:   strconv.FormatInt(task.Id, 10),
	}

	// Pick the cluster-specific output reference; unknown clusters report none.
	var out string
	if name := aiTask.ClusterName; name == "openI" {
		out = aiTask.JobId
	} else if name == "鹏城云脑II-modelarts" {
		out = aiTask.Output
	}

	msg.Status = true
	msg.Message = ""
	msg.ClusterID = strconv.FormatInt(aiTask.ClusterId, 10)
	msg.Output = out

	if err := jcs.StatusReport(s.config.JcsMiddleware.JobStatusReportUrl, msg); err != nil {
		return err
	}
	if err := jcs.TempSaveReportToTask(s.aiStorages, task, msg); err != nil {
		return err
	}
	return nil
}
// ReportStatus sends an "Inference" status report built from the given
// arguments to the job status report URL found in the scheduler's AI service
// configuration. It returns the error from jcs.StatusReport, or nil when the
// report was sent.
func ReportStatus(svc *svc.ServiceContext, taskName string, taskId string, clusterId string, url string, status bool, msg string) error {
	message := new(jcs.InferReportMessage)
	message.Type = "Inference"
	message.TaskName = taskName
	message.TaskID = taskId
	message.Status = status
	message.Message = msg
	message.ClusterID = clusterId
	message.Url = url

	target := svc.Scheduler.AiService.Conf.JcsMiddleware.JobStatusReportUrl
	if err := jcs.StatusReport(target, message); err != nil {
		return err
	}
	return nil
}