// Forked from JointCloud/pcm-coordinator (Go, 125 lines, 3.4 KiB).

package textInference
|
|
|
|
import (
	"errors"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
)
|
|
|
|
// CHAT is the path segment appended (after a forward slash) to every
// cluster infer URL used for text-to-text inference.
const CHAT = "chat"

// TEXTTOTEXT_AITYPE is the AI task type identifier reported by
// TextToText.GetAiType.
const TEXTTOTEXT_AITYPE = "12"
|
|
|
|
type TextToText struct {
|
|
opt *option.InferOption
|
|
storage *database.AiStorage
|
|
inferAdapter map[string]map[string]inference.ICluster
|
|
cs []*FilteredCluster
|
|
}
|
|
|
|
func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) {
|
|
cs, err := filterClusters(opt, storage, inferAdapter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &TextToText{
|
|
opt: opt,
|
|
storage: storage,
|
|
inferAdapter: inferAdapter,
|
|
cs: cs,
|
|
}, nil
|
|
}
|
|
|
|
func (tt *TextToText) GetAiType() string {
|
|
return TEXTTOTEXT_AITYPE
|
|
}
|
|
|
|
func (tt *TextToText) SaveAiTask(id int64, adapterName string) error {
|
|
|
|
if len(tt.cs) == 0 {
|
|
clusterId := tt.opt.AiClusterIds[0]
|
|
clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0])
|
|
err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败")
|
|
}
|
|
|
|
for _, c := range tt.cs {
|
|
clusterName, _ := tt.storage.GetClusterNameById(c.clusterId)
|
|
err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) {
|
|
var wg sync.WaitGroup
|
|
var ch = make(chan *FilteredCluster, len(opt.AiClusterIds))
|
|
var cs []*FilteredCluster
|
|
inferMap := inferAdapter[opt.AdapterId]
|
|
|
|
for _, clusterId := range opt.AiClusterIds {
|
|
wg.Add(1)
|
|
go func(cId string) {
|
|
r := http.Request{}
|
|
clusterInferUrl, err := inferMap[cId].GetClusterInferUrl(r.Context(), opt)
|
|
if err != nil {
|
|
wg.Done()
|
|
return
|
|
}
|
|
for i, _ := range clusterInferUrl.InferUrls {
|
|
clusterInferUrl.InferUrls[i].Url = clusterInferUrl.InferUrls[i].Url + inference.FORWARD_SLASH + CHAT
|
|
}
|
|
clusterName, _ := storage.GetClusterNameById(cId)
|
|
|
|
var f FilteredCluster
|
|
f.urls = clusterInferUrl.InferUrls
|
|
f.clusterId = cId
|
|
f.clusterName = clusterName
|
|
|
|
ch <- &f
|
|
wg.Done()
|
|
return
|
|
}(clusterId)
|
|
}
|
|
wg.Wait()
|
|
close(ch)
|
|
|
|
for s := range ch {
|
|
cs = append(cs, s)
|
|
}
|
|
|
|
return cs, nil
|
|
}
|
|
|
|
func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error {
|
|
for i, t := range aiTaskList {
|
|
if strconv.Itoa(int(t.ClusterId)) == tt.cs[i].clusterId {
|
|
t.Status = constants.Completed
|
|
t.EndTime = time.Now().Format(time.RFC3339)
|
|
url := tt.cs[i].urls[0].Url
|
|
t.InferUrl = url
|
|
err := tt.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成")
|
|
return nil
|
|
}
|