pcm-participant/client/api/ai.go

212 lines
6.4 KiB
Go

package api
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm"
"gitlink.org.cn/JointCloud/pcm-participant-ai/dataset"
aiModel "gitlink.org.cn/JointCloud/pcm-participant-ai/model"
"gitlink.org.cn/JointCloud/pcm-participant-ai/service"
"gitlink.org.cn/JointCloud/pcm-participant-openi/model"
"net/http"
"strconv"
"strings"
"sync"
)
// aiApi groups the HTTP handlers for AI algorithms, datasets, models,
// resources and training tasks. It embeds the shared *Api and forwards all
// work to the *service.Service installed through RegisterSvc.
type aiApi struct {
	*Api

	// svcLock is taken for writing by RegisterSvc when service is replaced.
	// NOTE(review): the handlers in this file read service without taking the
	// read lock, so calling RegisterSvc while requests are in flight would be a
	// data race — confirm it is only called before serving starts.
	svcLock sync.RWMutex
	// service is the backend every handler in this file calls into.
	service *service.Service
}
// AiApi is the package-level aiApi instance, built on BaseApi as its embedded
// *Api (presumably wired into the router elsewhere — confirm). Its service
// field stays nil until RegisterSvc is called.
var AiApi = aiApi{Api: BaseApi}
// RegisterSvc installs svc as the backend used by the AI handlers,
// replacing any previously registered service while holding the write lock.
func (a *aiApi) RegisterSvc(svc *service.Service) {
	a.svcLock.Lock()
	a.service = svc
	a.svcLock.Unlock()
}
// TrainAlgorithmsHandler handles requests for the training algorithms of a
// single platform. The platform is chosen by the "pfId" query parameter,
// which must parse as a base-10 int64; otherwise the response is 400.
// A service failure yields 500 with a generic message (the underlying error
// is not included in the response).
func (a *aiApi) TrainAlgorithmsHandler(c *gin.Context) {
	id, parseErr := strconv.ParseInt(c.Query("pfId"), 10, 64)
	if parseErr != nil {
		model.Response(c, http.StatusBadRequest, "invalid pfId", nil)
		return
	}

	algos, svcErr := a.service.TrainAlgorithms(c, id)
	if svcErr != nil {
		model.Response(c, http.StatusInternalServerError, "failed to get train algorithms", nil)
		return
	}
	model.Response(c, http.StatusOK, "success", algos)
}
// AllTrainAlgorithmsHandler handles requests for all training algorithms as
// returned by service.AllTrainAlgorithms. A service failure yields 500 with a
// generic message; on success the algorithms are returned as the data field.
func (a *aiApi) AllTrainAlgorithmsHandler(c *gin.Context) {
	if algos, err := a.service.AllTrainAlgorithms(c); err != nil {
		model.Response(c, http.StatusInternalServerError, "failed to get all train algorithms", nil)
	} else {
		model.Response(c, http.StatusOK, "success", algos)
	}
}
// CreateAlgorithmHandler handles requests to create an algorithm on the
// platform identified by the "pfId" query parameter (base-10 int64).
//
// The JSON request body is bound into an algorithm.CreateParam. Validation
// failures return 400 listing every failing field with the validator tag it
// violated; any other binding error returns 400 with the binder's message
// (both messages are in Chinese). A service failure returns 500.
func (a *aiApi) CreateAlgorithmHandler(c *gin.Context) {
	pfId, err := strconv.ParseInt(c.Query("pfId"), 10, 64)
	if err != nil {
		model.Response(c, http.StatusBadRequest, "invalid pfId", nil)
		return
	}

	var param algorithm.CreateParam
	if bindErr := c.ShouldBindJSON(&param); bindErr != nil {
		switch verrs := bindErr.(type) {
		case validator.ValidationErrors:
			// One "field X failed validation: tag" entry per failing field.
			details := make([]string, 0, len(verrs))
			for _, fe := range verrs {
				details = append(details, fmt.Sprintf("字段 %s 验证失败: %s", fe.Field(), fe.Tag()))
			}
			model.Response(c, http.StatusBadRequest, "请求体格式错误: "+strings.Join(details, "; "), nil)
		default:
			// Malformed JSON or type mismatch: surface the binder's message.
			model.Response(c, http.StatusBadRequest, "请求体解析失败: "+bindErr.Error(), nil)
		}
		return
	}

	created, err := a.service.CreateAlgorithm(c.Request.Context(), pfId, &param)
	if err != nil {
		model.Response(c, http.StatusInternalServerError, "failed to create algorithm", nil)
		return
	}
	model.Response(c, http.StatusOK, "success", created)
}
// CreateDatasetHandler handles requests to create a dataset on the platform
// identified by the "pfId" query parameter (base-10 int64).
//
// The JSON request body is bound into a dataset.CreateParam. Validation
// failures return 400 listing every failing field with the validator tag it
// violated; any other binding error returns 400 with the binder's message
// (both messages are in Chinese). A service failure returns 500 with
// "failed to create dataset".
func (a *aiApi) CreateDatasetHandler(c *gin.Context) {
	pfIdStr := c.Query("pfId")
	pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
	if err != nil {
		model.Response(c, http.StatusBadRequest, "invalid pfId", nil)
		return
	}
	var param dataset.CreateParam
	if err := c.ShouldBindJSON(&param); err != nil {
		if ve, ok := err.(validator.ValidationErrors); ok {
			// One "field X failed validation: tag" entry per failing field.
			errorMsg := make([]string, 0, len(ve))
			for _, e := range ve {
				errorMsg = append(errorMsg, fmt.Sprintf("字段 %s 验证失败: %s", e.Field(), e.Tag()))
			}
			model.Response(c, http.StatusBadRequest, "请求体格式错误: "+strings.Join(errorMsg, "; "), nil)
		} else {
			// Malformed JSON or type mismatch: surface the binder's message.
			model.Response(c, http.StatusBadRequest, "请求体解析失败: "+err.Error(), nil)
		}
		return
	}
	resp, err := a.service.CreateDataset(c.Request.Context(), pfId, &param)
	if err != nil {
		model.Response(c, http.StatusInternalServerError, "failed to create dataset", nil)
		return
	}
	model.Response(c, http.StatusOK, "success", resp)
}
// CreateModelHandler handles requests to create a model on the platform
// identified by the "pfId" query parameter (base-10 int64).
//
// The JSON request body is bound into an aiModel.CreateParam. Validation
// failures return 400 listing every failing field with the validator tag it
// violated; any other binding error returns 400 with the binder's message
// (both messages are in Chinese). A service failure returns 500 with
// "failed to create model".
func (a *aiApi) CreateModelHandler(c *gin.Context) {
	pfIdStr := c.Query("pfId")
	pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
	if err != nil {
		model.Response(c, http.StatusBadRequest, "invalid pfId", nil)
		return
	}
	var param aiModel.CreateParam
	if err := c.ShouldBindJSON(&param); err != nil {
		if ve, ok := err.(validator.ValidationErrors); ok {
			// One "field X failed validation: tag" entry per failing field.
			errorMsg := make([]string, 0, len(ve))
			for _, e := range ve {
				errorMsg = append(errorMsg, fmt.Sprintf("字段 %s 验证失败: %s", e.Field(), e.Tag()))
			}
			model.Response(c, http.StatusBadRequest, "请求体格式错误: "+strings.Join(errorMsg, "; "), nil)
		} else {
			// Malformed JSON or type mismatch: surface the binder's message.
			model.Response(c, http.StatusBadRequest, "请求体解析失败: "+err.Error(), nil)
		}
		return
	}
	resp, err := a.service.CreateModel(c.Request.Context(), pfId, &param)
	if err != nil {
		model.Response(c, http.StatusInternalServerError, "failed to create model", nil)
		return
	}
	model.Response(c, http.StatusOK, "success", resp)
}
// GetResourceSpecsHandler handles requests for the resource specifications of
// the platform identified by the "pfId" query parameter (base-10 int64).
// The "rType" query parameter is forwarded unchanged to the service (it is
// "" when absent); its meaning is defined by service.GetResourceSpecs.
//
// An unparsable pfId yields 400, a service failure yields 500, and on success
// the specs are returned as the data field.
//
// NOTE(review): assumes GetResourceSpecs returns (value, error) like the other
// service methods used in this file — confirm against the service package.
func (a *aiApi) GetResourceSpecsHandler(c *gin.Context) {
	pfIdStr := c.Query("pfId")
	rtype := c.Query("rType")
	pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
	if err != nil {
		model.Response(c, http.StatusBadRequest, "invalid pfId", nil)
		return
	}
	// Previously the result and error were discarded, so callers always got
	// 200 "success" with nil data, even when the lookup failed.
	specs, err := a.service.GetResourceSpecs(c, pfId, rtype)
	if err != nil {
		model.Response(c, http.StatusInternalServerError, "failed to get resource specs", nil)
		return
	}
	model.Response(c, http.StatusOK, "success", specs)
}
// TrainResourcesHandler handles requests for the training resources of one
// platform, selected by the "pfId" query parameter (base-10 int64). An
// unparsable pfId yields 400; a service failure yields 500 with a generic
// message; on success the resources are returned as the data field.
func (a *aiApi) TrainResourcesHandler(c *gin.Context) {
	raw := c.Query("pfId")
	platformID, parseErr := strconv.ParseInt(raw, 10, 64)
	if parseErr != nil {
		model.Response(c, http.StatusBadRequest, "invalid pfId", nil)
		return
	}

	if res, err := a.service.TrainResources(c, platformID); err != nil {
		model.Response(c, http.StatusInternalServerError, "failed to get train resources", nil)
	} else {
		model.Response(c, http.StatusOK, "success", res)
	}
}
// AllTrainResourcesHandler handles requests for all training resources as
// returned by service.AllTrainResources. A service failure yields 500 with a
// generic message; on success the resources are returned as the data field.
func (a *aiApi) AllTrainResourcesHandler(c *gin.Context) {
	res, err := a.service.AllTrainResources(c)
	switch {
	case err != nil:
		model.Response(c, http.StatusInternalServerError, "failed to get all train resources", nil)
	default:
		model.Response(c, http.StatusOK, "success", res)
	}
}
// CreateTrainTaskHandler handles requests to create a training task.
//
// The JSON request body is bound into a service.CreateTrainTaskParam; any
// binding or validation failure returns 400 "invalid request body". When the
// service fails, the response is 400 "failed to create train task" and the
// data field carries the service error's message.
func (a *aiApi) CreateTrainTaskHandler(c *gin.Context) {
	var param service.CreateTrainTaskParam
	if err := c.ShouldBindJSON(&param); err != nil {
		model.Response(c, http.StatusBadRequest, "invalid request body", nil)
		return
	}
	resp, err := a.service.CreateTrainTask(c, &param)
	if err != nil {
		// Send the message rather than the error value. Common error types
		// (errors.New, fmt.Errorf) have only unexported fields, so JSON-encoding
		// the value itself yields {} (assuming model.Response JSON-encodes data).
		// NOTE(review): sibling handlers use 500 for service failures; 400 is
		// kept here to avoid changing the HTTP contract — confirm intent.
		model.Response(c, http.StatusBadRequest, "failed to create train task", err.Error())
		return
	}
	model.Response(c, http.StatusOK, "success", resp)
}