forked from JointCloud/pcm-coordinator
151 lines
4.0 KiB
Go
151 lines
4.0 KiB
Go
package inference
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type TextToTextInferenceLogic struct {
|
|
logx.Logger
|
|
ctx context.Context
|
|
svcCtx *svc.ServiceContext
|
|
}
|
|
|
|
func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TextToTextInferenceLogic {
|
|
return &TextToTextInferenceLogic{
|
|
Logger: logx.WithContext(ctx),
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
}
|
|
}
|
|
|
|
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
|
|
resp = &types.TextToTextInferenceResp{}
|
|
opt := &option.InferOption{
|
|
TaskName: req.TaskName,
|
|
TaskDesc: req.TaskDesc,
|
|
AdapterId: req.AdapterId,
|
|
AiClusterIds: req.AiClusterIds,
|
|
ModelName: req.ModelName,
|
|
ModelType: req.ModelType,
|
|
}
|
|
|
|
_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
|
|
if !ok {
|
|
return nil, errors.New("AdapterId does not exist")
|
|
}
|
|
|
|
//save task
|
|
var synergystatus int64
|
|
var strategyCode int64
|
|
adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
var cluster_ch = make(chan struct {
|
|
urls []*inference.InferUrl
|
|
clusterId string
|
|
clusterName string
|
|
}, len(opt.AiClusterIds))
|
|
|
|
var cs []struct {
|
|
urls []*inference.InferUrl
|
|
clusterId string
|
|
clusterName string
|
|
}
|
|
inferMap := l.svcCtx.Scheduler.AiService.InferenceAdapterMap[opt.AdapterId]
|
|
|
|
//save taskai
|
|
for _, clusterId := range opt.AiClusterIds {
|
|
wg.Add(1)
|
|
go func(cId string) {
|
|
urls, err := inferMap[cId].GetInferUrl(l.ctx, opt)
|
|
if err != nil {
|
|
wg.Done()
|
|
return
|
|
}
|
|
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId)
|
|
|
|
s := struct {
|
|
urls []*inference.InferUrl
|
|
clusterId string
|
|
clusterName string
|
|
}{
|
|
urls: urls,
|
|
clusterId: cId,
|
|
clusterName: clusterName,
|
|
}
|
|
|
|
cluster_ch <- s
|
|
wg.Done()
|
|
return
|
|
}(clusterId)
|
|
}
|
|
wg.Wait()
|
|
close(cluster_ch)
|
|
|
|
for s := range cluster_ch {
|
|
cs = append(cs, s)
|
|
}
|
|
|
|
if len(cs) == 0 {
|
|
clusterId := opt.AiClusterIds[0]
|
|
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(opt.AiClusterIds[0])
|
|
err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, clusterId, clusterName, "", constants.Failed, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
|
|
}
|
|
|
|
for _, c := range cs {
|
|
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId)
|
|
err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
var aiTaskList []*models.TaskAi
|
|
tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
|
|
}
|
|
|
|
for i, t := range aiTaskList {
|
|
if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId {
|
|
t.Status = constants.Completed
|
|
t.EndTime = time.Now().Format(time.RFC3339)
|
|
url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat"
|
|
t.InferUrl = url
|
|
err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(tx.Error.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
|
|
|
|
return resp, nil
|
|
}
|