288 lines
8.0 KiB
Go
288 lines
8.0 KiB
Go
package api
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/go-playground/validator/v10"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/algorithm"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/dataset"
|
|
aiModel "gitlink.org.cn/JointCloud/pcm-participant-ai/model"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/service"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-ai/task"
|
|
"gitlink.org.cn/JointCloud/pcm-participant-openi/model"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
type aiApi struct {
|
|
*Api
|
|
svcLock sync.RWMutex
|
|
service *service.Service
|
|
}
|
|
|
|
var AiApi = aiApi{
|
|
Api: BaseApi,
|
|
}
|
|
|
|
func (a *aiApi) RegisterSvc(svc *service.Service) {
|
|
a.svcLock.Lock()
|
|
defer a.svcLock.Unlock()
|
|
|
|
a.service = svc
|
|
}
|
|
|
|
// TrainAlgorithmsHandler 处理获取特定平台训练算法的请求
|
|
func (a *aiApi) TrainAlgorithmsHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
algorithms, err := a.service.TrainAlgorithms(c, pfId)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", algorithms)
|
|
}
|
|
|
|
// AllTrainAlgorithmsHandler 处理获取所有训练算法的请求
|
|
func (a *aiApi) AllTrainAlgorithmsHandler(c *gin.Context) {
|
|
algorithms, err := a.service.AllTrainAlgorithms(c)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", algorithms)
|
|
}
|
|
|
|
// CreateAlgorithmHandler 处理创建算法的请求
|
|
func (a *aiApi) CreateAlgorithmHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
var param algorithm.CreateParam
|
|
if err := c.ShouldBindJSON(¶m); err != nil {
|
|
var ve validator.ValidationErrors
|
|
if errors.As(err, &ve) {
|
|
var errorMsg []string
|
|
for _, e := range ve {
|
|
errorMsg = append(errorMsg, fmt.Sprintf("字段 %s 验证失败: %s", e.Field(), e.Tag()))
|
|
}
|
|
model.Response(c, http.StatusBadRequest, "请求体格式错误: "+strings.Join(errorMsg, "; "), err)
|
|
}
|
|
}
|
|
|
|
resp, err := a.service.CreateAlgorithm(c.Request.Context(), pfId, ¶m)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
model.Response(c, http.StatusOK, "success", resp)
|
|
}
|
|
|
|
// CreateDatasetHandler 处理创建数据集的请求
|
|
func (a *aiApi) CreateDatasetHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
var param dataset.CreateParam
|
|
if err := c.ShouldBindJSON(¶m); err != nil {
|
|
var ve validator.ValidationErrors
|
|
if errors.As(err, &ve) {
|
|
var errorMsg []string
|
|
for _, e := range ve {
|
|
errorMsg = append(errorMsg, fmt.Sprintf("字段 %s 验证失败: %s", e.Field(), e.Tag()))
|
|
}
|
|
model.Response(c, http.StatusBadRequest, errorMsg, nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
resp, err := a.service.CreateDataset(c.Request.Context(), pfId, ¶m)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
model.Response(c, http.StatusOK, "success", resp)
|
|
}
|
|
|
|
// CreateModelHandler 处理创建模型的请求
|
|
func (a *aiApi) CreateModelHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
var param aiModel.CreateParam
|
|
if err := c.ShouldBindJSON(¶m); err != nil {
|
|
var ve validator.ValidationErrors
|
|
if errors.As(err, &ve) {
|
|
var errorMsg []string
|
|
for _, e := range ve {
|
|
errorMsg = append(errorMsg, fmt.Sprintf("字段 %s 验证失败: %s", e.Field(), e.Tag()))
|
|
}
|
|
model.Response(c, http.StatusBadRequest, "请求体格式错误: "+strings.Join(errorMsg, "; "), nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
resp, err := a.service.CreateModel(c.Request.Context(), pfId, ¶m)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
model.Response(c, http.StatusOK, "success", resp)
|
|
}
|
|
|
|
// GetResourceSpecsHandler 处理获取资源规格的请求
|
|
func (a *aiApi) GetResourceSpecsHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
rtype := c.Query("rType")
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
_, err = a.service.GetResourceSpecs(c, pfId, rtype)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", err)
|
|
}
|
|
|
|
// TrainResourcesHandler 处理获取特定平台训练资源的请求
|
|
func (a *aiApi) TrainResourcesHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
resources, err := a.service.TrainResources(c, pfId)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", resources)
|
|
}
|
|
|
|
// AllTrainResourcesHandler 处理获取所有训练资源的请求
|
|
func (a *aiApi) AllTrainResourcesHandler(c *gin.Context) {
|
|
resources, err := a.service.AllTrainResources(c)
|
|
if err != nil {
|
|
model.Response(c, http.StatusInternalServerError, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", resources)
|
|
}
|
|
|
|
// CreateTrainTaskHandler 处理创建训练任务的请求
|
|
func (a *aiApi) CreateTrainTaskHandler(c *gin.Context) {
|
|
var param service.CreateTrainTaskParam
|
|
if err := c.ShouldBindJSON(¶m); err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
resp, err := a.service.CreateTrainTask(c, ¶m)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", resp)
|
|
}
|
|
|
|
// TaskResultSyncHandler 同步任务结果数据
|
|
func (a *aiApi) TaskResultSyncHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
var param task.ResultSyncParam
|
|
if err := c.ShouldBindJSON(¶m); err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
err = a.service.TaskResultSync(c, pfId, ¶m)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", err)
|
|
}
|
|
|
|
// TaskLogHandler 查询任务日志
|
|
func (a *aiApi) TaskLogHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
taskId := c.Query("taskId")
|
|
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
_, err = a.service.TaskLog(c, pfId, taskId)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", err)
|
|
}
|
|
|
|
// TrainTaskDetailHandler 查询训练任务详情
|
|
func (a *aiApi) TrainTaskDetailHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
taskId := c.Query("taskId")
|
|
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
_, err = a.service.GetTrainingTask(c, pfId, taskId)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", err)
|
|
}
|
|
|
|
// InferTaskDetailHandler 查询推理任务详情
|
|
func (a *aiApi) InferTaskDetailHandler(c *gin.Context) {
|
|
pfIdStr := c.Query("pfId")
|
|
taskId := c.Query("taskId")
|
|
|
|
pfId, err := strconv.ParseInt(pfIdStr, 10, 64)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
_, err = a.service.GetInferenceTask(c, pfId, taskId)
|
|
if err != nil {
|
|
model.Response(c, http.StatusBadRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
model.Response(c, http.StatusOK, "success", err)
|
|
}
|