pcm-coordinator/internal/scheduler/service/inference/textInference/textToText.go

125 lines
3.4 KiB
Go

package textInference
import (
	"errors"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
)
// Constants used by the text-to-text inference service.
const (
// CHAT is the path segment appended (after inference.FORWARD_SLASH) to every
// inference URL returned by a cluster in filterClusters.
CHAT = "chat"
// TEXTTOTEXT_AITYPE is the AI type identifier returned by TextToText.GetAiType.
TEXTTOTEXT_AITYPE = "12"
)
// TextToText holds the state needed to record and update text-to-text
// inference tasks. Instances are built by NewTextToText.
type TextToText struct {
// opt is the inference option supplied by the caller; its AdapterId,
// AiClusterIds and TaskName fields are read by the methods in this file.
opt *option.InferOption
// storage is used to save/update AI tasks, resolve cluster names and add
// notice entries.
storage *database.AiStorage
// inferAdapter is indexed first by adapter id and then by cluster id.
inferAdapter map[string]map[string]inference.ICluster
// cs is the list of clusters produced by filterClusters at construction time.
cs []*FilteredCluster
}
// NewTextToText builds a TextToText for the given inference option.
//
// It runs filterClusters once, up front, and stores the resulting cluster list
// on the returned value together with the three arguments. If filterClusters
// reports an error, that error is returned and no TextToText is created.
func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) (*TextToText, error) {
	clusters, err := filterClusters(opt, storage, inferAdapter)
	if err != nil {
		return nil, err
	}

	svc := &TextToText{
		opt:          opt,
		storage:      storage,
		inferAdapter: inferAdapter,
	}
	svc.cs = clusters
	return svc, nil
}
// GetAiType returns the AI type identifier of this service, which is the
// TEXTTOTEXT_AITYPE constant.
func (tt *TextToText) GetAiType() string {
	aiType := TEXTTOTEXT_AITYPE
	return aiType
}
// SaveAiTask persists one AI task record per filtered cluster for the task
// identified by id.
//
// When no cluster passed filtering (tt.cs is empty), a single record with
// status constants.Failed is saved against the first requested cluster id and
// a "failed" notice is added. If the option carries no cluster ids at all, an
// error is returned instead, because there is no cluster to attribute the
// failure to.
//
// Otherwise every filtered cluster gets a record with status constants.Saved.
// The first storage error aborts the call and is returned as-is.
func (tt *TextToText) SaveAiTask(id int64, adapterName string) error {
	if len(tt.cs) == 0 {
		// Guard the index below: an empty id list would otherwise panic.
		if len(tt.opt.AiClusterIds) == 0 {
			return errors.New("save ai task: inference option contains no cluster ids")
		}

		clusterId := tt.opt.AiClusterIds[0]
		// The lookup error is deliberately ignored; whatever name came back is stored.
		clusterName, _ := tt.storage.GetClusterNameById(clusterId)
		if err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, ""); err != nil {
			return err
		}
		// The notice message is a user-facing string meaning "task failed".
		tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败")
		return nil
	}

	for _, c := range tt.cs {
		// The lookup error is deliberately ignored; whatever name came back is stored.
		clusterName, _ := tt.storage.GetClusterNameById(c.clusterId)
		if err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, ""); err != nil {
			return err
		}
	}
	return nil
}
// filterClusters asks every cluster listed in opt.AiClusterIds, concurrently,
// for its inference URLs and returns the clusters that answered.
//
// Clusters are looked up in inferAdapter[opt.AdapterId]. A cluster id that has
// no entry there, or whose GetClusterInferUrl call fails, is logged and left
// out of the result; it does not fail the whole call. Every URL of a returned
// cluster has inference.FORWARD_SLASH + CHAT appended.
//
// The result order follows goroutine completion order, not the order of
// opt.AiClusterIds. The returned error is currently always nil.
func filterClusters(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) {
	var wg sync.WaitGroup
	// Buffered to the number of requested clusters so no sender ever blocks.
	ch := make(chan *FilteredCluster, len(opt.AiClusterIds))
	// Indexing a missing adapter yields a nil map; lookups on it are safe and
	// simply report "not found" below.
	inferMap := inferAdapter[opt.AdapterId]

	for _, clusterId := range opt.AiClusterIds {
		wg.Add(1)
		go func(cId string) {
			defer wg.Done()

			// Calling a method on a missing (nil) map entry would panic and take
			// the whole process down, so skip unknown clusters explicitly.
			cluster, ok := inferMap[cId]
			if !ok || cluster == nil {
				logx.Errorf("filter clusters: no inference cluster %q registered for adapter %q", cId, opt.AdapterId)
				return
			}

			// A zero-value request is used only to obtain its default context.
			r := http.Request{}
			clusterInferUrl, err := cluster.GetClusterInferUrl(r.Context(), opt)
			if err != nil {
				logx.Errorf("filter clusters: get infer url for cluster %q: %v", cId, err)
				return
			}

			for i := range clusterInferUrl.InferUrls {
				clusterInferUrl.InferUrls[i].Url += inference.FORWARD_SLASH + CHAT
			}

			// The lookup error is deliberately ignored; whatever name came back is stored.
			clusterName, _ := storage.GetClusterNameById(cId)
			ch <- &FilteredCluster{
				urls:        clusterInferUrl.InferUrls,
				clusterId:   cId,
				clusterName: clusterName,
			}
		}(clusterId)
	}

	// All senders are done before the channel is closed and drained.
	wg.Wait()
	close(ch)

	var cs []*FilteredCluster
	for c := range ch {
		cs = append(cs, c)
	}
	return cs, nil
}
// UpdateStatus marks the tasks in aiTaskList that belong to a filtered cluster
// as completed and stores the cluster's first inference URL on them.
//
// Tasks are matched to clusters by cluster id rather than by slice position,
// because filterClusters returns clusters in goroutine completion order and
// the two slices are not guaranteed to have the same length. Tasks whose
// cluster is not among the filtered clusters, or whose cluster reported no
// URLs, are left untouched.
//
// The first storage error is logged and returned, and no notice is added in
// that case. On success a "completed" notice is added and nil is returned.
func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error {
	// Index the filtered clusters by id; the first entry wins on duplicates.
	clustersByID := make(map[string]*FilteredCluster, len(tt.cs))
	for _, c := range tt.cs {
		if c == nil {
			continue
		}
		if _, seen := clustersByID[c.clusterId]; !seen {
			clustersByID[c.clusterId] = c
		}
	}

	for _, t := range aiTaskList {
		if t == nil {
			continue
		}
		c, ok := clustersByID[strconv.Itoa(int(t.ClusterId))]
		if !ok || len(c.urls) == 0 {
			// No matching cluster, or nothing to record as the inference URL.
			continue
		}

		t.Status = constants.Completed
		t.EndTime = time.Now().Format(time.RFC3339)
		t.InferUrl = c.urls[0].Url
		if err := tt.storage.UpdateAiTask(t); err != nil {
			// Use a constant format string: the error text may contain '%'.
			logx.Errorf("update ai task for cluster %s: %v", c.clusterId, err)
			return err
		}
	}

	// The notice message is a user-facing string meaning "task completed".
	tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成")
	return nil
}