forked from JointCloud/pcm-coordinator
95 lines
2.8 KiB
Go
95 lines
2.8 KiB
Go
package ai
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/tls"
|
||
"github.com/go-resty/resty/v2"
|
||
"github.com/pkg/errors"
|
||
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
||
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
||
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/hws"
|
||
"k8s.io/apimachinery/pkg/util/json"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
|
||
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
|
||
|
||
"github.com/zeromicro/go-zero/core/logx"
|
||
)
|
||
|
||
type ChatLogic struct {
|
||
logx.Logger
|
||
ctx context.Context
|
||
svcCtx *svc.ServiceContext
|
||
}
|
||
|
||
// NewChatLogic returns a ChatLogic bound to ctx, with a logger derived from
// that context and a reference to the shared service context.
func NewChatLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatLogic {
	logic := new(ChatLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
|
||
|
||
// Chat forwards req.ReqData (JSON-encoded) to the inference endpoint of the AI
// task identified by req.Id and decodes the endpoint's reply into a
// types.ChatResult.
//
// The task's InferUrl comes from the TaskAi model and its cluster from the
// t_cluster table. When the cluster's Label matches constants.MODELARTS
// (case-insensitively), the call is signed with the cluster's AK/SK via
// hws.Signer. The resulting Authorization and X-Sdk-Date headers are attached,
// together with X-Project-Id and x-stage.
//
// Every failure is logged with its underlying cause and reported to the caller
// as a user-facing error. resp is nil whenever err is non-nil.
func (l *ChatLogic) Chat(req *types.ChatReq) (resp *types.ChatResult, err error) {
	jsonBytes, err := json.Marshal(&req.ReqData)
	if err != nil {
		l.Errorf("【序列化请求数据失败: %v】", err)
		return nil, errors.New("请求数据序列化失败")
	}

	taskAi := models.TaskAi{}
	if err := l.svcCtx.DbEngin.Model(models.TaskAi{}).Where("id", req.Id).Scan(&taskAi).Error; err != nil {
		l.Errorf("【查询推理任务失败, id: %v, 错误: %v】", req.Id, err)
		return nil, errors.New("查询推理任务失败")
	}
	// NOTE(review): assuming DbEngin is a *gorm.DB, Scan does not report an
	// error when no row matches. An empty InferUrl is therefore the signal that
	// the task is missing or has no endpoint yet; without this check the
	// request would be POSTed to an empty URL.
	if taskAi.InferUrl == "" {
		l.Errorf("【推理任务不存在或推理地址为空, id: %v】", req.Id)
		return nil, errors.New("推理服务地址不存在")
	}
	l.Infof("【开始处理请求,目标URL: %s】", taskAi.InferUrl)

	// The hosting cluster decides whether the call must be signed.
	cluster := models.CloudModel{}
	if err := l.svcCtx.DbEngin.Table("t_cluster").Where("id", taskAi.ClusterId).Scan(&cluster).Error; err != nil {
		l.Errorf("【查询集群信息失败, clusterId: %v, 错误: %v】", taskAi.ClusterId, err)
		return nil, errors.New("查询集群信息失败")
	}

	// NOTE(review): TLS certificate verification is disabled, presumably
	// because inference endpoints use self-signed certificates. Revisit this if
	// they are reachable over untrusted networks.
	client := resty.New().SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
	// Tie the outbound call to the caller's context so it is abandoned when the
	// incoming request is cancelled or times out.
	restReq := client.R().SetContext(l.ctx)

	// ModelArts endpoints require an AK/SK signature.
	if strings.EqualFold(cluster.Label, constants.MODELARTS) {
		// hws.Signer works on a *http.Request, so build one with the same
		// method, URL and body purely to obtain the signature headers. This
		// request is never sent.
		signReq, err := http.NewRequest(http.MethodPost, taskAi.InferUrl, bytes.NewBuffer(jsonBytes))
		if err != nil {
			l.Errorf("【构建 HTTP 请求失败: %v】", err)
			return nil, errors.New("网络错误,请稍后重试")
		}
		signer := &hws.Signer{
			Key:    cluster.Ak,
			Secret: cluster.Sk,
		}
		if err := signer.Sign(signReq); err != nil {
			l.Errorf("【接口签名错误: %v】", err)
			return nil, errors.New("网络错误,请稍后重试")
		}
		restReq.
			SetHeader("X-Project-Id", cluster.ProjectId).
			SetHeader("x-stage", "RELEASE").
			SetHeader("Authorization", signReq.Header.Get(hws.HeaderXAuthorization)).
			SetHeader("X-Sdk-Date", signReq.Header.Get(hws.HeaderXDateTime))
	}

	// resp is passed as a *ChatResult rather than **ChatResult. A JSON `null`
	// reply therefore cannot nil it out and make Chat return (nil, nil).
	resp = &types.ChatResult{}
	response, err := restReq.
		SetHeader("Content-Type", "application/json").
		SetBody(jsonBytes).
		SetResult(resp).
		Post(taskAi.InferUrl)
	if err != nil {
		l.Errorf("【远程调用接口URL:%s, 返回错误: %v】", taskAi.InferUrl, err)
		return nil, errors.New("网络错误,请稍后重试")
	}

	if response.StatusCode() != http.StatusOK {
		l.Errorf("【远程调用接口URL:%s, 返回错误: %s】", taskAi.InferUrl, response.Body())
		return nil, errors.New("网络错误,请稍后重试")
	}

	l.Infof("【请求处理成功,目标URL: %s】", taskAi.InferUrl)
	return resp, nil
}
|