pcm-coordinator/internal/scheduler/service/utils/task/tasksync/train.go

225 lines
6.7 KiB
Go

package tasksync
import (
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/config"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/utils/jcs"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"strconv"
"sync"
)
// SyncTrain synchronizes the status of AI training tasks: it reads the current
// job state from the per-cluster collectors and writes it back to storage.
type SyncTrain struct {
// mu is held for the full duration of UpdateAiTaskStatus and UpdateTaskStatus,
// so calls to those two methods on the same SyncTrain are serialized.
mu sync.Mutex
// aiStorages is the storage handle used to load and update tasks and AI sub-tasks.
aiStorages *database.AiStorage
// aiCollectorAdapterMap is indexed first by the decimal adapter ID and then by
// the decimal cluster ID (this is how updateAiTask looks collectors up).
aiCollectorAdapterMap map[string]map[string]collector.AiCollector
// config supplies JcsMiddleware.JobStatusReportUrl, read by reportStatusMessages.
config *config.Config
}
// NewTrainTask builds a SyncTrain around the given storage handle, collector
// lookup table (adapter ID -> cluster ID -> collector) and configuration.
//
// The arguments are stored as passed in; nothing is copied or validated here.
func NewTrainTask(storage *database.AiStorage, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, config *config.Config) *SyncTrain {
	syncer := new(SyncTrain)
	syncer.aiStorages = storage
	syncer.aiCollectorAdapterMap = aiCollectorAdapterMap
	syncer.config = config
	return syncer
}
// UpdateAiTaskStatus refreshes the AI sub-tasks of every task in tasklist that
// is still in flight.
//
// Only tasks whose AdapterTypeDict is "1" and whose Status is neither
// constants.Succeeded nor constants.Failed are considered. For each of them
// the AI sub-tasks are loaded from storage and handed to updateAiTask in a new
// goroutine. At most 10 of those goroutines are active at a time. The method
// returns, releasing the lock, as soon as the last one has been started; it
// does not wait for them to finish.
//
// The tasklist slice is not modified.
func (s *SyncTrain) UpdateAiTaskStatus(tasklist []*types.TaskModel) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Single-pass filter into a fresh slice so the caller's slice stays intact.
	pending := make([]*types.TaskModel, 0, len(tasklist))
	for _, task := range tasklist {
		if task.AdapterTypeDict != "1" || task.Status == constants.Succeeded || task.Status == constants.Failed {
			continue
		}
		pending = append(pending, task)
	}
	if len(pending) == 0 {
		return
	}

	// buffer is a counting semaphore: a slot is taken here before each worker
	// starts and handed back by updateAiTask when that worker is done.
	buffer := make(chan bool, 10)
	for _, task := range pending {
		aiTaskList, err := s.aiStorages.GetAiTaskListById(task.Id)
		if err != nil {
			logx.Errorf("UpdateAiTaskStatus Get AiTask Error %s", err.Error())
			// A result that comes with an error is not trustworthy; skip this
			// task and let the next sync pass retry it.
			continue
		}
		if len(aiTaskList) == 0 {
			continue
		}
		buffer <- true
		go s.updateAiTask(aiTaskList, buffer)
	}
}
// UpdateTaskStatus copies the state of each task's single AI sub-task onto the
// task itself and persists the task.
//
// Only tasks whose AdapterTypeDict is "1" and whose Status is not Succeeded,
// Failed or Cancelled are considered. A task is processed only when storage
// returns exactly one AI sub-task for it; tasks with zero or several sub-tasks
// are left untouched. The mapping is:
//
//   - sub-task Completed -> task Succeeded, and a success report is sent
//   - sub-task Failed    -> task Failed, and a failure report is sent
//   - anything else      -> task takes the sub-task's status verbatim
//
// StartTime and EndTime are copied from the sub-task in every case. The
// *types.TaskModel values in tasklist are modified in place; the slice itself
// is not. Errors are logged and never returned; a failed report does not stop
// the task from being saved.
func (s *SyncTrain) UpdateTaskStatus(tasklist []*types.TaskModel) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Single-pass filter into a fresh slice so the caller's slice stays intact.
	pending := make([]*types.TaskModel, 0, len(tasklist))
	for _, task := range tasklist {
		if task.AdapterTypeDict != "1" || task.Status == constants.Succeeded || task.Status == constants.Failed || task.Status == constants.Cancelled {
			continue
		}
		pending = append(pending, task)
	}

	for _, task := range pending {
		aiTasks, err := s.aiStorages.GetAiTaskListById(task.Id)
		if err != nil {
			logx.Errorf("UpdateTaskStatus Get AiTask Error %s", err.Error())
			// Do not derive a task status from a lookup that failed.
			continue
		}
		if len(aiTasks) != 1 {
			continue
		}
		aiTask := aiTasks[0]

		switch aiTask.Status {
		case constants.Completed:
			task.Status = constants.Succeeded
			// NOTE(review): emitted at error level although it reads like an
			// informational trace; level kept as in the original, confirm intent.
			logx.Errorf("############ Report Status Message: %s", task.Status)
			if err := s.reportStatusMessages(task, aiTask, true); err != nil {
				logx.Errorf("reportStatusMessages Error %s", err.Error())
			}
		case constants.Failed:
			task.Status = constants.Failed
			logx.Errorf("############ Report Status Message: %s", task.Status)
			if err := s.reportStatusMessages(task, aiTask, false); err != nil {
				logx.Errorf("reportStatusMessages Error %s", err.Error())
			}
		default:
			task.Status = aiTask.Status
		}

		task.StartTime = aiTask.StartTime
		task.EndTime = aiTask.EndTime
		if err := s.aiStorages.UpdateTask(task); err != nil {
			logx.Errorf("UpdateTaskStatus Update Task Error %s", err.Error())
		}
	}
}
// updateAiTask refreshes every AI sub-task in aiTaskList that can still
// change, one goroutine per sub-task, and waits for all of them.
//
// Sub-tasks that are already Completed, Failed or Cancelled, or that have no
// JobId, are skipped. When all refreshes are done one value is received from
// ch, which hands back the concurrency slot that UpdateAiTaskStatus took
// before starting this call.
func (s *SyncTrain) updateAiTask(aiTaskList []*models.TaskAi, ch chan bool) {
	var wg sync.WaitGroup
	for _, aitask := range aiTaskList {
		t := aitask // per-iteration copy for the closure below
		if t.Status == constants.Completed || t.Status == constants.Failed || t.JobId == "" || t.Status == constants.Cancelled {
			continue
		}
		wg.Add(1)
		go func() {
			// Deferred so that every exit path of the refresh is counted.
			defer wg.Done()
			s.refreshAiTask(t)
		}()
	}
	wg.Wait()
	<-ch
}

// refreshAiTask asks the collector registered for t's adapter and cluster for
// the current state of job t.JobId, records a notice when the status changed,
// and stores the new status and start/end times.
//
// Failures are logged and otherwise ignored; t is not persisted in that case.
func (s *SyncTrain) refreshAiTask(t *models.TaskAi) {
	// logFailure writes the common error line for this sub-task. The message is
	// passed as a value, not as a format string, so '%' characters in the
	// error text are printed verbatim.
	logFailure := func(err error) {
		msg := fmt.Sprintf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error())
		logx.Error(msg)
	}

	adapterID := strconv.FormatInt(t.AdapterId, 10)
	clusterID := strconv.FormatInt(t.ClusterId, 10)

	// Indexing a missing adapter or cluster yields a nil collector; calling a
	// method on it would panic inside this goroutine.
	aiCollector := s.aiCollectorAdapterMap[adapterID][clusterID]
	if aiCollector == nil {
		logFailure(errors.New("no ai collector registered for this adapter and cluster"))
		return
	}

	// The zero http.Request carries no context, so Context() returns the
	// background context.
	h := http.Request{}
	trainingTask, err := aiCollector.GetTrainingTask(h.Context(), t.JobId)
	if err != nil {
		// Timeouts are recognised separately but, as before, handled exactly
		// like any other failure: log and leave the sub-task for the next pass.
		if status.Code(err) == codes.DeadlineExceeded {
			logFailure(err)
			return
		}
		logFailure(err)
		return
	}
	if trainingTask == nil {
		return
	}

	if t.Status != trainingTask.Status {
		var noticeType, noticeMsg string
		switch trainingTask.Status {
		case constants.Running:
			noticeType, noticeMsg = "running", "任务运行中"
		case constants.Failed:
			noticeType, noticeMsg = "failed", "任务失败"
		case constants.Completed:
			noticeType, noticeMsg = "completed", "任务完成"
		default:
			// Every other status is announced as pending.
			noticeType, noticeMsg = "pending", "任务pending"
		}
		s.aiStorages.AddNoticeInfo(adapterID, t.AdapterName, clusterID, t.ClusterName, t.Name, noticeType, noticeMsg)
		t.Status = trainingTask.Status
	}

	t.StartTime = trainingTask.Start
	t.EndTime = trainingTask.End
	if err := s.aiStorages.UpdateAiTask(t); err != nil {
		logFailure(err)
	}
}
// reportStatusMessages sends a "Train" status report for task to the URL in
// s.config.JcsMiddleware.JobStatusReportUrl and then passes the same report to
// jcs.TempSaveReportToTask together with the storage handle and the task.
//
// status is forwarded as the report's Status field. The report's Output field
// depends on aiTask.ClusterName: the job ID for "openI", aiTask.Output for
// "鹏城云脑II-modelarts", and empty for every other cluster.
//
// The first error from either step is returned; if sending fails the report
// is not saved.
func (s *SyncTrain) reportStatusMessages(task *types.TaskModel, aiTask *models.TaskAi, status bool) error {
	result := ""
	if aiTask.ClusterName == "openI" {
		result = aiTask.JobId
	} else if aiTask.ClusterName == "鹏城云脑II-modelarts" {
		result = aiTask.Output
	}

	payload := &jcs.JobStatusReportReq{}
	payload.Report = &jcs.TrainReportMessage{
		Type:      "Train",
		TaskName:  task.Name,
		TaskID:    strconv.FormatInt(task.Id, 10),
		Status:    status,
		Message:   "",
		ClusterID: strconv.FormatInt(aiTask.ClusterId, 10),
		Output:    result,
	}

	if sendErr := jcs.StatusReport(s.config.JcsMiddleware.JobStatusReportUrl, payload); sendErr != nil {
		return sendErr
	}
	return jcs.TempSaveReportToTask(s.aiStorages, task, payload)
}