// Forked from JointCloud/pcm-coordinator.
// Source listing: 156 lines, 4.3 KiB, Go.

package textInference
|
|
|
|
import (
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// CHAT is the path segment that filterClusters and filterClustersTemp append,
// after a forward slash, to every inference URL they collect.
const CHAT = "chat"

// TEXTTOTEXT_AITYPE is the AI type identifier returned by TextToText.GetAiType.
const TEXTTOTEXT_AITYPE = "12"
|
|
|
|
type TextToText struct {
|
|
opt *option.InferOption
|
|
storage *database.AiStorage
|
|
inferAdapter map[string]map[string]inference.ICluster
|
|
instance *models.AiInferDeployInstance
|
|
cs []*FilteredCluster
|
|
}
|
|
|
|
func NewTextToText(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster, instance *models.AiInferDeployInstance) (*TextToText, error) {
|
|
cs, err := filterClusters(inferAdapter, instance)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &TextToText{
|
|
opt: opt,
|
|
storage: storage,
|
|
inferAdapter: inferAdapter,
|
|
cs: cs,
|
|
}, nil
|
|
}
|
|
|
|
func (tt *TextToText) GetAiType() string {
|
|
return TEXTTOTEXT_AITYPE
|
|
}
|
|
|
|
func (tt *TextToText) SaveAiTask(id int64, adapterName string) error {
|
|
|
|
if len(tt.cs) == 0 {
|
|
clusterId := tt.opt.AiClusterIds[0]
|
|
clusterName, _ := tt.storage.GetClusterNameById(tt.opt.AiClusterIds[0])
|
|
err := tt.storage.SaveAiTask(id, tt.opt, adapterName, clusterId, clusterName, "", constants.Failed, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "failed", "任务失败")
|
|
}
|
|
|
|
for _, c := range tt.cs {
|
|
clusterName, _ := tt.storage.GetClusterNameById(c.clusterId)
|
|
err := tt.storage.SaveAiTask(id, tt.opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func filterClusters(inferAdapter map[string]map[string]inference.ICluster, instance *models.AiInferDeployInstance) ([]*FilteredCluster, error) {
|
|
var cs []*FilteredCluster
|
|
var inferurls []*inference.InferUrl
|
|
clusterId := strconv.FormatInt(instance.ClusterId, 10)
|
|
adapterId := strconv.FormatInt(instance.AdapterId, 10)
|
|
r := http.Request{}
|
|
deployInstance, err := inferAdapter[adapterId][clusterId].GetInferDeployInstance(r.Context(), instance.InstanceId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var url inference.InferUrl
|
|
url.Url = deployInstance.InferUrl + inference.FORWARD_SLASH + CHAT
|
|
url.Card = deployInstance.InferCard
|
|
inferurls = append(inferurls, &url)
|
|
|
|
clusterType := deployInstance.ClusterType
|
|
clusterName := deployInstance.ClusterName
|
|
|
|
var f FilteredCluster
|
|
f.urls = inferurls
|
|
f.clusterId = clusterId
|
|
f.clusterName = clusterName
|
|
f.clusterType = clusterType
|
|
cs = append(cs, &f)
|
|
|
|
return cs, nil
|
|
}
|
|
|
|
func filterClustersTemp(opt *option.InferOption, storage *database.AiStorage, inferAdapter map[string]map[string]inference.ICluster) ([]*FilteredCluster, error) {
|
|
var wg sync.WaitGroup
|
|
var ch = make(chan *FilteredCluster, len(opt.AiClusterIds))
|
|
var cs []*FilteredCluster
|
|
inferMap := inferAdapter[opt.AdapterId]
|
|
|
|
for _, clusterId := range opt.AiClusterIds {
|
|
wg.Add(1)
|
|
go func(cId string) {
|
|
r := http.Request{}
|
|
clusterInferUrl, err := inferMap[cId].GetClusterInferUrl(r.Context(), opt)
|
|
if err != nil {
|
|
wg.Done()
|
|
return
|
|
}
|
|
|
|
for i, _ := range clusterInferUrl.InferUrls {
|
|
clusterInferUrl.InferUrls[i].Url = clusterInferUrl.InferUrls[i].Url + inference.FORWARD_SLASH + CHAT
|
|
}
|
|
|
|
clusterName, _ := storage.GetClusterNameById(cId)
|
|
|
|
var f FilteredCluster
|
|
f.urls = clusterInferUrl.InferUrls
|
|
f.clusterId = cId
|
|
f.clusterName = clusterName
|
|
|
|
ch <- &f
|
|
wg.Done()
|
|
return
|
|
}(clusterId)
|
|
}
|
|
wg.Wait()
|
|
close(ch)
|
|
|
|
for s := range ch {
|
|
cs = append(cs, s)
|
|
}
|
|
|
|
return cs, nil
|
|
}
|
|
|
|
func (tt *TextToText) UpdateStatus(aiTaskList []*models.TaskAi, adapterName string) error {
|
|
for i, t := range aiTaskList {
|
|
if strconv.Itoa(int(t.ClusterId)) == tt.cs[i].clusterId {
|
|
t.Status = constants.Completed
|
|
t.EndTime = time.Now().Format(time.RFC3339)
|
|
url := tt.cs[i].urls[0].Url
|
|
t.InferUrl = url
|
|
err := tt.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
tt.storage.AddNoticeInfo(tt.opt.AdapterId, adapterName, "", "", tt.opt.TaskName, "completed", "任务完成")
|
|
return nil
|
|
}
|