pcm-coordinator/internal/scheduler/service/updater/taskStatusSync.go

383 lines
9.4 KiB
Go

package updater
import (
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"strconv"
"sync"
"time"
)
// UpdateTaskStatus re-derives the status of one cloud-AI task (AdapterTypeDict "1")
// from its AI sub-tasks and persists it.
//
// Only non-terminal tasks (not Succeeded/Failed) are considered, and only the one whose
// UpdatedTime is oldest is processed per call. tasklist itself is never modified.
// Inference tasks (TaskTypeDict "11"/"12") are delegated to updateInferTaskStatus.
// Storage errors are logged; nothing is returned.
func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
	// Keep only AI tasks that are not yet finished; building a fresh slice avoids
	// touching the caller's slice and the quadratic delete-in-place filter.
	list := make([]*types.TaskModel, 0, len(tasklist))
	for _, t := range tasklist {
		if t.AdapterTypeDict != "1" || t.Status == constants.Succeeded || t.Status == constants.Failed {
			continue
		}
		list = append(list, t)
	}
	if len(list) == 0 {
		return
	}

	// UpdatedTime is parsed as RFC 3339 elsewhere in this package and as constants.Layout
	// in UpdateAiTaskStatus; accept either so the oldest task is actually found instead of
	// every value collapsing to the zero time (which would always select list[0]).
	parseUpdated := func(s string) time.Time {
		if ts, err := time.Parse(time.RFC3339, s); err == nil {
			return ts
		}
		ts, _ := time.Parse(constants.Layout, s) // zero time if neither layout matches
		return ts
	}

	task := list[0]
	oldest := parseUpdated(task.UpdatedTime)
	for _, t := range list[1:] {
		if ts := parseUpdated(t.UpdatedTime); ts.Before(oldest) {
			task, oldest = t, ts
		}
	}

	// Inference tasks are aggregated with their own rules.
	if task.TaskTypeDict == "11" || task.TaskTypeDict == "12" {
		updateInferTaskStatus(svc, *task)
		return
	}

	aiTasks, err := svc.Scheduler.AiStorages.GetAiTaskListById(task.Id)
	if err != nil {
		// logx.Error rather than Errorf(err.Error()): the error text must not be used as a format string.
		logx.Error(err)
		return
	}

	save := func() {
		if err := svc.Scheduler.AiStorages.UpdateTask(task); err != nil {
			logx.Errorf("UpdateTaskStatus: update task %v: %v", task.Id, err)
		}
	}

	switch len(aiTasks) {
	case 0:
		save()
		return
	case 1:
		only := aiTasks[0]
		if only.Status == constants.Completed {
			task.Status = constants.Succeeded
		} else {
			task.Status = only.Status
		}
		task.StartTime = only.StartTime
		task.EndTime = only.EndTime
		save()
		return
	}

	// Split off sub-tasks that have no start time. The task takes the status of the
	// first such sub-task (same result as the original reverse-order scan).
	started := aiTasks[:0]
	sawUnstarted := false
	for _, a := range aiTasks {
		if a.StartTime != "" {
			started = append(started, a)
			continue
		}
		if !sawUnstarted {
			task.Status = a.Status
			sawUnstarted = true
		}
	}
	if len(started) == 0 {
		save()
		return
	}

	// Time window over all started sub-tasks: earliest start, latest end. Empty or
	// unparseable timestamps are skipped instead of being treated as the zero time.
	var start, end time.Time
	for _, a := range started {
		if s, err := time.ParseInLocation(constants.Layout, a.StartTime, time.Local); err == nil && (start.IsZero() || s.Before(start)) {
			start = s
		}
		if e, err := time.ParseInLocation(constants.Layout, a.EndTime, time.Local); err == nil && e.After(end) {
			end = e
		}
	}

	// Status precedence: a Failed sub-task wins outright; otherwise the last Pending/Running
	// sub-task seen decides; if every sub-task is Completed the task Succeeded.
	var status string
	completed := 0
statusLoop:
	for _, a := range started {
		switch a.Status {
		case constants.Failed:
			status = a.Status
			break statusLoop
		case constants.Pending, constants.Running:
			status = a.Status
		case constants.Completed:
			completed++
		}
	}
	if completed == len(started) {
		status = constants.Succeeded
	}

	if status != "" {
		task.Status = status
		if !start.IsZero() {
			task.StartTime = start.Format(constants.Layout)
		}
		if !end.IsZero() {
			task.EndTime = end.Format(constants.Layout)
		}
	}
	save()
}
// updateInferTaskStatus re-derives the status of an inference task from its AI sub-tasks
// and persists the updated copy. task is received by value, so only storage is affected.
//
// Timestamps of inference sub-tasks are handled as RFC 3339 strings. With several
// sub-tasks, the task spans the earliest start to the latest end. EndTime is only
// recorded once the task Succeeded. Storage errors are logged.
func updateInferTaskStatus(svc *svc.ServiceContext, task types.TaskModel) {
	aiTasks, err := svc.Scheduler.AiStorages.GetAiTaskListById(task.Id)
	if err != nil {
		// logx.Error rather than Errorf(err.Error()): the error text must not be used as a format string.
		logx.Error(err)
		return
	}

	save := func() {
		if err := svc.Scheduler.AiStorages.UpdateTask(&task); err != nil {
			logx.Errorf("updateInferTaskStatus: update task %v: %v", task.Id, err)
		}
	}

	switch len(aiTasks) {
	case 0:
		save()
		return
	case 1:
		only := aiTasks[0]
		task.StartTime = only.StartTime
		if only.Status == constants.Completed {
			task.EndTime = only.EndTime
			task.Status = constants.Succeeded
		} else {
			task.Status = only.Status
		}
		save()
		return
	}

	// Earliest start and latest end across sub-tasks; empty or unparseable timestamps
	// are skipped so a zero time is never persisted.
	var start, end time.Time
	for _, a := range aiTasks {
		if s, err := time.ParseInLocation(time.RFC3339, a.StartTime, time.Local); err == nil && (start.IsZero() || s.Before(start)) {
			start = s
		}
		if e, err := time.ParseInLocation(time.RFC3339, a.EndTime, time.Local); err == nil && e.After(end) {
			end = e
		}
	}
	if start.IsZero() {
		// No sub-task has a start time yet: leave the task untouched.
		return
	}

	// Status precedence: Failed wins outright; otherwise the last Pending/Running
	// sub-task seen decides; all Completed means Succeeded.
	var status string
	completed := 0
statusLoop:
	for _, a := range aiTasks {
		switch a.Status {
		case constants.Failed:
			status = a.Status
			break statusLoop
		case constants.Pending, constants.Running:
			status = a.Status
		case constants.Completed:
			completed++
		}
	}
	if completed == len(aiTasks) {
		status = constants.Succeeded
	}

	// Never overwrite the stored status with an empty one when no sub-task status
	// was recognised.
	if status != "" {
		task.Status = status
	}
	task.StartTime = start.Format(time.RFC3339)
	if status == constants.Succeeded && !end.IsZero() {
		task.EndTime = end.Format(time.RFC3339)
	}
	save()
}
// UpdateAiTaskStatus refreshes the AI sub-tasks of one cloud-AI task (AdapterTypeDict "1").
// It works on the non-terminal task with the oldest UpdatedTime and polls the registered
// collector for each sub-task's training job.
//
// Sub-tasks that are Completed/Failed or have no JobId are skipped. Polls run concurrently,
// one goroutine per sub-task, and the function returns after all of them finish. When a
// sub-task's status changes, a notice is recorded before the new status and start/end
// times are stored. Failures are logged per sub-task and never abort the others.
func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
	// Keep only unfinished AI tasks; tasklist itself is left untouched.
	list := make([]*types.TaskModel, 0, len(tasklist))
	for _, t := range tasklist {
		if t.AdapterTypeDict != "1" || t.Status == constants.Succeeded || t.Status == constants.Failed {
			continue
		}
		list = append(list, t)
	}
	if len(list) == 0 {
		return
	}

	// UpdatedTime is parsed with constants.Layout here but as RFC 3339 in UpdateTaskStatus;
	// accept either so the oldest task is really selected.
	parseUpdated := func(s string) time.Time {
		if ts, err := time.Parse(constants.Layout, s); err == nil {
			return ts
		}
		ts, _ := time.Parse(time.RFC3339, s) // zero time if neither layout matches
		return ts
	}
	task := list[0]
	oldest := parseUpdated(task.UpdatedTime)
	for _, t := range list[1:] {
		if ts := parseUpdated(t.UpdatedTime); ts.Before(oldest) {
			task, oldest = t, ts
		}
	}

	aiTaskList, err := svc.Scheduler.AiStorages.GetAiTaskListById(task.Id)
	if err != nil {
		logx.Error(err)
		return
	}
	if len(aiTaskList) == 0 {
		return
	}

	var wg sync.WaitGroup
	for _, aitask := range aiTaskList {
		t := aitask // per-iteration copy for the goroutine (pre-Go 1.22 loop semantics)
		if t.Status == constants.Completed || t.Status == constants.Failed || t.JobId == "" {
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()

			// logx.Error (not Errorf) so '%' characters inside the message are printed verbatim.
			logFailure := func(err error) {
				msg := fmt.Sprintf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error())
				logx.Error(errors.New(msg))
			}

			adapterID := strconv.FormatInt(t.AdapterId, 10)
			clusterID := strconv.FormatInt(t.ClusterId, 10)
			// A missing adapter or cluster would yield a nil collector, and calling a method
			// on it would panic inside this goroutine and crash the process.
			collector, ok := svc.Scheduler.AiService.AiCollectorAdapterMap[adapterID][clusterID]
			if !ok {
				logFailure(fmt.Errorf("no collector registered for adapter %v cluster %v", adapterID, clusterID))
				return
			}

			h := http.Request{}
			trainingTask, err := collector.GetTrainingTask(h.Context(), t.JobId)
			if err != nil {
				// NOTE(review): deadline-exceeded is classified separately but currently
				// handled exactly like any other RPC error.
				if status.Code(err) == codes.DeadlineExceeded {
					logFailure(err)
					return
				}
				logFailure(err)
				return
			}
			if trainingTask == nil {
				return
			}

			if t.Status != trainingTask.Status {
				noticeType, noticeMsg := "pending", "任务pending"
				switch trainingTask.Status {
				case constants.Running:
					noticeType, noticeMsg = "running", "任务运行中"
				case constants.Failed:
					noticeType, noticeMsg = "failed", "任务失败"
				case constants.Completed:
					noticeType, noticeMsg = "completed", "任务完成"
				}
				svc.Scheduler.AiStorages.AddNoticeInfo(adapterID, t.AdapterName, clusterID, t.ClusterName, t.Name, noticeType, noticeMsg)
				t.Status = trainingTask.Status
			}
			t.StartTime = trainingTask.Start
			t.EndTime = trainingTask.End

			if err := svc.Scheduler.AiStorages.UpdateAiTask(t); err != nil {
				logFailure(err)
			}
		}()
	}
	wg.Wait()
}
// UpdateTrainingTaskStatus polls, for every adapter in list, the remote state of its
// unfinished pytorch AI tasks and stores the reported status and start/end times.
//
// Tasks that are Completed, Failed or Stopped, are not "pytorch", or have no JobId are
// skipped. Polls run concurrently, one goroutine per task, and the function returns after
// all of them finish. Failures are logged per task and never abort the others.
func UpdateTrainingTaskStatus(svc *svc.ServiceContext, list []*types.AdapterInfo) {
	var wg sync.WaitGroup
	for _, adapter := range list {
		// Copy the Id: goroutines must not read the range variable, which (before Go 1.22)
		// is reused and may already point at a later adapter when they run.
		adapterID := adapter.Id
		taskList, err := svc.Scheduler.AiStorages.GetAiTasksByAdapterId(adapterID)
		if err != nil {
			logx.Errorf("UpdateTrainingTaskStatus: list ai tasks of adapter %v: %v", adapterID, err)
			continue
		}
		for _, task := range taskList {
			t := task // per-iteration copy for the goroutine
			if t.Status == constants.Completed || t.Status == constants.Failed || t.Status == constants.Stopped ||
				t.TaskType != "pytorch" || t.JobId == "" {
				continue
			}
			wg.Add(1)
			go func() {
				defer wg.Done()

				// logx.Error (not Errorf) so '%' characters inside the message are printed verbatim.
				logFailure := func(err error) {
					msg := fmt.Sprintf("AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error())
					logx.Error(errors.New(msg))
				}

				clusterID := strconv.FormatInt(t.ClusterId, 10)
				// Guard the nested lookup: a nil collector would panic and crash the process.
				collector, ok := svc.Scheduler.AiService.AiCollectorAdapterMap[adapterID][clusterID]
				if !ok {
					logFailure(fmt.Errorf("no collector registered for adapter %v cluster %v", adapterID, clusterID))
					return
				}

				h := http.Request{}
				trainingTask, err := collector.GetTrainingTask(h.Context(), t.JobId)
				if err != nil {
					logFailure(err)
					return
				}
				if trainingTask == nil {
					return
				}

				t.Status = trainingTask.Status
				t.StartTime = trainingTask.Start
				t.EndTime = trainingTask.End
				if err := svc.Scheduler.AiStorages.UpdateAiTask(t); err != nil {
					logFailure(err)
				}
			}()
		}
	}
	wg.Wait()
}