added textinfer api

Former-commit-id: bfdce90251
This commit is contained in:
tzwang 2024-06-25 18:19:18 +08:00
parent 60012ab0fb
commit e6b9d3d23b
9 changed files with 140 additions and 31 deletions

View File

@ -1,28 +1,25 @@
package inference
import (
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
)
func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.TextToTextInferenceReq
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
result.ParamErrorResult(r, w, err)
return
}
l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx)
resp, err := l.TextToTextInference(&req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
} else {
httpx.OkJsonCtx(r.Context(), w, resp)
}
result.HttpResult(r, w, resp, err)
}
}

View File

@ -133,14 +133,14 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
var wg sync.WaitGroup
var cluster_ch = make(chan struct {
urls []*collector.ImageInferUrl
urls []*collector.InferUrl
clusterId string
clusterName string
imageNum int32
}, len(clusters))
var cs []struct {
urls []*collector.ImageInferUrl
urls []*collector.InferUrl
clusterId string
clusterName string
imageNum int32
@ -182,7 +182,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
wg.Add(1)
c := cluster
go func() {
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt)
imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt)
if err != nil {
wg.Done()
return
@ -190,7 +190,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
s := struct {
urls []*collector.ImageInferUrl
urls []*collector.InferUrl
clusterId string
clusterName string
imageNum int32
@ -373,7 +373,7 @@ func sendInferReq(images []struct {
imageResult *types.ImageResult
file multipart.File
}, cluster struct {
urls []*collector.ImageInferUrl
urls []*collector.InferUrl
clusterId string
clusterName string
imageNum int32
@ -384,7 +384,7 @@ func sendInferReq(images []struct {
imageResult *types.ImageResult
file multipart.File
}, c struct {
urls []*collector.ImageInferUrl
urls []*collector.InferUrl
clusterId string
clusterName string
imageNum int32
@ -494,7 +494,7 @@ type Res struct {
}
func contains(cs []struct {
urls []*collector.ImageInferUrl
urls []*collector.InferUrl
clusterId string
clusterName string
imageNum int32

View File

@ -2,11 +2,18 @@ package inference
import (
"context"
"errors"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"strconv"
"sync"
"time"
)
type TextToTextInferenceLogic struct {
@ -24,7 +31,110 @@ func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext
}
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
// todo: add your logic here and delete this line
resp = &types.TextToTextInferenceResp{}
opt := &option.InferOption{
TaskName: req.TaskName,
TaskDesc: req.TaskDesc,
AdapterId: req.AdapterId,
AiClusterIds: req.AiClusterIds,
ModelName: req.ModelName,
ModelType: req.ModelType,
}
return
_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
if !ok {
return nil, errors.New("AdapterId does not exist")
}
//save task
var synergystatus int64
var strategyCode int64
adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
if err != nil {
return nil, err
}
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12")
if err != nil {
return nil, err
}
var wg sync.WaitGroup
var cluster_ch = make(chan struct {
urls []*collector.InferUrl
clusterId string
clusterName string
}, len(opt.AiClusterIds))
var cs []struct {
urls []*collector.InferUrl
clusterId string
clusterName string
}
collectorMap := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
//save taskai
for _, clusterId := range opt.AiClusterIds {
wg.Add(1)
go func(cId string) {
urls, err := collectorMap[cId].GetInferUrl(l.ctx, opt)
if err != nil {
wg.Done()
return
}
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId)
s := struct {
urls []*collector.InferUrl
clusterId string
clusterName string
}{
urls: urls,
clusterId: cId,
clusterName: clusterName,
}
cluster_ch <- s
wg.Done()
return
}(clusterId)
}
wg.Wait()
close(cluster_ch)
for s := range cluster_ch {
cs = append(cs, s)
}
for _, c := range cs {
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId)
err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
if err != nil {
return nil, err
}
}
var aiTaskList []*models.TaskAi
tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
if tx.Error != nil {
return nil, tx.Error
}
for i, t := range aiTaskList {
if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId {
t.Status = constants.Completed
t.EndTime = time.Now().Format(time.RFC3339)
url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat"
t.InferUrl = url
err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
return resp, nil
}

View File

@ -15,10 +15,10 @@ type AiCollector interface {
UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error
GetComputeCards(ctx context.Context) ([]string, error)
GetUserBalance(ctx context.Context) (float64, error)
GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error)
GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error)
}
type ImageInferUrl struct {
type InferUrl struct {
Url string
Card string
}

View File

@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option.
return errors.New("failed to get AlgorithmId")
}
func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
var imageUrls []*collector.ImageInferUrl
func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.InferUrl
urlReq := &modelartsclient.ImageReasoningUrlReq{
ModelName: option.ModelName,
Type: option.ModelType,
@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf
if err != nil {
return nil, err
}
imageUrl := &collector.ImageInferUrl{
imageUrl := &collector.InferUrl{
Url: urlResp.Url,
Card: "npu",
}

View File

@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec
return errors.New("set ResourceId error")
}
func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
req := &octopus.GetNotebookListReq{
Platform: o.platform,
PageIndex: o.pageIndex,
@ -882,12 +882,12 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer
return nil, err
}
var imageUrls []*collector.ImageInferUrl
var imageUrls []*collector.InferUrl
for _, notebook := range list.Payload.GetNotebooks() {
if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" {
url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1)
names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
imageUrl := &collector.ImageInferUrl{
imageUrl := &collector.InferUrl{
Url: DOMAIN + url + FORWARD_SLASH + "image",
Card: names[2],
}

View File

@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error {
return errors.New("failed to set params")
}
func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
var imageUrls []*collector.ImageInferUrl
func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.InferUrl
urlReq := &hpcAC.GetInferUrlReq{
ModelName: option.ModelName,
@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO
if err != nil {
return nil, err
}
imageUrl := &collector.ImageInferUrl{
imageUrl := &collector.InferUrl{
Url: urlResp.Url,
Card: "dcu",
}

View File

@ -78,6 +78,7 @@ var (
}
ModelTypeMap = map[string][]string{
"image_recognition": {"imagenet_resnet50"},
"text_to_text": {"chatGLM-6B"},
}
AITYPE = map[string]string{
"1": OCTOPUS,

View File

@ -54,6 +54,7 @@ type (
TaskType string `db:"task_type"`
DeletedAt *time.Time `db:"deleted_at"`
Card string `db:"card"`
InferUrl string `db:"infer_url"`
}
)