pcm-openi/service/model.go

252 lines
7.5 KiB
Go

package service
import (
"errors"
"fmt"
"github.com/go-resty/resty/v2"
json "github.com/json-iterator/go"
"gitlink.org.cn/JointCloud/pcm-openi/common"
"gitlink.org.cn/JointCloud/pcm-openi/model"
"io"
"math"
"mime/multipart"
"net/http"
"strconv"
)
// ModelService is the receiver for the model operations defined in this
// package (creating and listing models, chunked model-file upload).
// It has no fields, so its zero value is ready to use.
type ModelService struct{}
// NewModelService returns a pointer to a newly allocated, zero-valued
// ModelService.
func NewModelService() *ModelService {
	return new(ModelService)
}
// CreateLocalModel sends the model attributes in param as a form-encoded POST
// to the common.MODELLOCALCREATE endpoint, with param.UserName and
// param.RepoName filled in as the "username" and "reponame" path parameters.
//
// The access token travels as a form field next to the model attributes.
// resp receives the decoded response body; err is whatever common.Request
// reports.
func (r ModelService) CreateLocalModel(token string, param model.CreateLocalModelParam) (resp model.CreateLocalModel, err error) {
	// Decode target for an error-shaped response body. It is registered with
	// the request but not inspected afterwards.
	apiErr := new(model.RespErr)

	form := map[string]string{
		"name":             param.Name,
		"version":          param.Version,
		"engine":           strconv.Itoa(param.Engine),
		"label":            param.Label,
		"isPrivate":        strconv.FormatBool(param.IsPrivate),
		"description":      param.Description,
		"type":             strconv.Itoa(param.Type),
		"license":          param.License,
		common.ACCESSTOKEN: token,
	}

	configure := func(req *resty.Request) {
		req.SetPathParam("username", param.UserName)
		req.SetPathParam("reponame", param.RepoName)
		req.SetFormData(form)
		req.SetError(apiErr)
		req.SetResult(&resp)
	}

	_, err = common.Request(common.MODELLOCALCREATE, http.MethodPost, configure)
	return resp, err
}
// ListModel queries the models of a repository (paged listing).
//
// It issues a GET to the common.PageModel endpoint with param.UserName and
// param.RepoName as the "username" and "reponame" path parameters and the
// access token as a query parameter. The decoded response is returned in
// resp; err is whatever common.Request reports. A nil param is rejected with
// an error instead of panicking inside the request callback.
func (r ModelService) ListModel(token string, param *model.QueryModelParam) (resp *model.ListModelResp, err error) {
	if param == nil {
		return nil, errors.New("list model: nil query parameter")
	}

	respErr := &model.RespErr{}
	_, err = common.Request(common.PageModel, http.MethodGet, func(req *resty.Request) {
		// respErr is the decode target for error responses, so it must be
		// registered with SetError. Registering it through SetResult would be
		// overwritten by the SetResult(&resp) call that follows.
		req.SetPathParam("username", param.UserName).
			SetPathParam("reponame", param.RepoName).
			SetQueryParam(common.ACCESSTOKEN, token).
			SetError(respErr).
			SetResult(&resp)
	})
	return resp, err
}
// getChunks asks the server which chunks of the given model file have already
// been uploaded.
//
// Every argument is forwarded unchanged as a query parameter of a GET request
// to common.MODELLOCALGETUPLOADEDCHUNKS. A request error yields (nil, err);
// otherwise the decoded response is returned with a nil error.
func (r ModelService) getChunks(token, md5, fileName, dataType, modelUuid, size string) (resp *model.GetChunksResp, err error) {
	query := map[string]string{
		"md5":              md5,
		"file_name":        fileName,
		"type":             dataType,
		"modeluuid":        modelUuid,
		"size":             size,
		common.ACCESSTOKEN: token,
	}

	_, err = common.Request(common.MODELLOCALGETUPLOADEDCHUNKS, http.MethodGet, func(req *resty.Request) {
		req.SetQueryParams(query)
		req.SetResult(&resp)
	})
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// newMultipart starts a new multipart upload for a local model file.
//
// All arguments are forwarded unchanged as query parameters of a GET request
// to common.MODELLOCALNEWMULTIPART. On success the decoded response, carrying
// a non-empty UploadID and Uuid, is returned.
//
// An error is returned when the request fails, or when the response was not
// decoded or lacks UploadID/Uuid. In the latter case the error text is the
// "msg" field of the raw response body, with a generic fallback when that
// field is empty.
func (r ModelService) newMultipart(token, totalChunkCounts, md5, fileName, dataType, modelUuid, size string) (resp *model.NewMultipartResp, err error) {
	res, err := common.Request(common.MODELLOCALNEWMULTIPART, http.MethodGet, func(req *resty.Request) {
		req.SetQueryParams(map[string]string{
			"totalChunkCounts": totalChunkCounts,
			"type":             dataType,
			"size":             size,
			"md5":              md5,
			"file_name":        fileName,
			"modeluuid":        modelUuid,
			common.ACCESSTOKEN: token,
		}).SetResult(&resp)
	})
	if err != nil {
		return nil, err
	}

	// resp is only populated through the request's result decoding. If
	// nothing was decoded it is still nil and must not be dereferenced.
	if resp == nil || resp.UploadID == "" || resp.Uuid == "" {
		msg := json.Get(res, "msg").ToString()
		if msg == "" {
			msg = "new multipart upload: response carries no uploadID/uuid"
		}
		// errors.New rather than fmt.Errorf: msg is server-supplied text and
		// must not be interpreted as a format string.
		return nil, errors.New(msg)
	}
	return resp, nil
}
// getMultipartUrl fetches the URL to which one chunk of a model file is to be
// uploaded.
//
// All arguments are forwarded unchanged as query parameters of a GET request
// to common.MODELLOCALGETMULTIPARTURL. On success the decoded response,
// carrying a non-empty Url, is returned.
//
// An error is returned when the request fails, or when the response was not
// decoded or has an empty Url. In the latter case the error text is the "msg"
// field of the raw response body, with a generic fallback when that field is
// empty.
func (r ModelService) getMultipartUrl(token, uuid, uploadID, fileName, dataType, modelUuid, size, chunkNumber string) (resp *model.GetMultipartUrlResp, err error) {
	res, err := common.Request(common.MODELLOCALGETMULTIPARTURL, http.MethodGet, func(req *resty.Request) {
		req.SetQueryParams(map[string]string{
			"uuid":             uuid,
			"uploadID":         uploadID,
			"type":             dataType,
			"size":             size,
			"chunkNumber":      chunkNumber,
			"file_name":        fileName,
			"modeluuid":        modelUuid,
			common.ACCESSTOKEN: token,
		}).SetResult(&resp)
	})
	if err != nil {
		return nil, err
	}

	// resp is only populated through the request's result decoding. If
	// nothing was decoded it is still nil and must not be dereferenced.
	if resp == nil || resp.Url == "" {
		msg := json.Get(res, "msg").ToString()
		if msg == "" {
			msg = "get multipart url: response carries no url"
		}
		// errors.New rather than fmt.Errorf: msg is server-supplied text and
		// must not be interpreted as a format string.
		return nil, errors.New(msg)
	}
	return resp, nil
}
// upLoadChunk uploads one chunk by sending the content of reader in an HTTP
// PUT request to reqUrl.
//
// The Content-Type header is explicitly set to the empty string. token and
// fileName are accepted but not used. The response body is read to the end
// before the status is examined; any status other than 200 OK is reported as
// an error whose text is the response status line.
func (r ModelService) upLoadChunk(token string, reqUrl, fileName string, reader io.Reader) (err error) {
	putReq, err := http.NewRequest(http.MethodPut, reqUrl, reader)
	if err != nil {
		return err
	}
	putReq.Header.Set("Content-Type", "")

	httpClient := new(http.Client)
	putResp, err := httpClient.Do(putReq)
	if err != nil {
		return err
	}
	defer putResp.Body.Close()

	// Consume the whole response body before looking at the status code.
	if _, err = io.ReadAll(putResp.Body); err != nil {
		return err
	}

	if putResp.StatusCode == http.StatusOK {
		return nil
	}
	return errors.New(putResp.Status)
}
// completeMultipart tells the server that all chunks of a model file have
// been sent, finishing the multipart upload.
//
// All arguments are sent as form fields of a POST request to
// common.MODELLOCALCOMPLETEMULTIPART. The decoded response is returned on
// success.
//
// An error is returned when the request fails, when no response was decoded,
// or when the response's ResultCode is "-1" (in which case the error text is
// the response's Msg).
func (r ModelService) completeMultipart(token string, uuid, uploadID, fileName, size, modelUuid, dataType string) (resp *model.CompleteMultipartResp, err error) {
	_, err = common.Request(common.MODELLOCALCOMPLETEMULTIPART, http.MethodPost, func(req *resty.Request) {
		req.SetFormData(map[string]string{
			"uuid":             uuid,
			"uploadID":         uploadID,
			"type":             dataType,
			"modeluuid":        modelUuid,
			"file_name":        fileName,
			"size":             size,
			common.ACCESSTOKEN: token,
		}).SetResult(&resp)
	})
	if err != nil {
		return nil, err
	}

	// resp is only populated through the request's result decoding. If
	// nothing was decoded it is still nil and must not be dereferenced.
	if resp == nil {
		return nil, errors.New("complete multipart upload: empty response")
	}
	if resp.ResultCode == "-1" {
		// errors.New rather than fmt.Errorf: Msg is server-supplied text and
		// must not be interpreted as a format string.
		return nil, errors.New(resp.Msg)
	}
	return resp, nil
}
// UploadFile uploads every file in fileHeaders to the model identified by
// param.ModelUuid, using the chunked (multipart) upload calls of this service.
//
// Files are processed one after another in slice order; the first failure
// aborts the whole call and is returned. respId is always the empty string.
func (r ModelService) UploadFile(param model.ModelUploadFileParam, token string, fileHeaders []*multipart.FileHeader) (respId string, err error) {
	for _, fileHeader := range fileHeaders {
		if err = r.uploadModelFile(token, param.ModelUuid, fileHeader); err != nil {
			return "", err
		}
	}
	return "", nil
}

// uploadModelFile uploads a single file in chunks of common.MaxChunkSize
// bytes: it looks up or starts a multipart upload, PUTs each chunk (up to
// three attempts per chunk) and finally completes the upload. Handling one
// file per call lets the deferred Close run as soon as that file is done.
func (r ModelService) uploadModelFile(token, modelUuid string, fileHeader *multipart.FileHeader) error {
	const (
		dataType    = "1"
		maxAttempts = 3
	)

	if fileHeader == nil {
		return errors.New("upload file: nil file header")
	}

	// Step 1: gather the values every subsequent request needs.
	fileName := fileHeader.Filename
	fileSize := fileHeader.Size
	totalChunkCounts := int(math.Ceil(float64(fileSize) / float64(common.MaxChunkSize)))
	fileSizeStr := strconv.FormatInt(fileSize, 10)

	file, err := fileHeader.Open()
	if err != nil {
		return fmt.Errorf("文件打开失败: %w", err)
	}
	defer file.Close()

	md5hash, err := common.GetFileMd5(file)
	if err != nil {
		return err
	}

	// Step 2: ask the server about a previous upload of this file.
	chunks, err := r.getChunks(token, md5hash, fileName, dataType, modelUuid, fileSizeStr)
	if err != nil {
		return err
	}
	if chunks == nil {
		return errors.New("upload file: empty response when querying uploaded chunks")
	}
	if chunks.Uploaded == "1" {
		return fmt.Errorf("该文件已上传在模型: %s", chunks.ModelName)
	}

	// Reuse the upload reported by the server when it names both identifiers;
	// otherwise start a new multipart upload.
	uuid, uploadID := chunks.Uuid, chunks.UploadID
	if uuid == "" || uploadID == "" {
		created, err := r.newMultipart(token, strconv.Itoa(totalChunkCounts), md5hash, fileName, dataType, modelUuid, fileSizeStr)
		if err != nil {
			return err
		}
		uuid, uploadID = created.Uuid, created.UploadID
	}

	// Step 3: upload the chunks in order; chunk numbers are 1-based.
	for chunkNumber := 1; chunkNumber <= totalChunkCounts; chunkNumber++ {
		multipartUrl, err := r.getMultipartUrl(token, uuid, uploadID, fileName, dataType, modelUuid, fileSizeStr, strconv.Itoa(chunkNumber))
		if err != nil {
			return err
		}

		offset := int64(chunkNumber-1) * common.MaxChunkSize
		var uploadErr error
		for attempt := 1; attempt <= maxAttempts; attempt++ {
			// A fresh section reader per attempt: a failed attempt may have
			// consumed part of the previous reader, and reusing it would send
			// a truncated or empty chunk.
			chunkReader := io.NewSectionReader(file, offset, common.MaxChunkSize)
			if uploadErr = r.upLoadChunk(token, multipartUrl.Url, fileName, chunkReader); uploadErr == nil {
				break
			}
		}
		if uploadErr != nil {
			return fmt.Errorf("error uploading chunk %d after %d attempts: %w", chunkNumber, maxAttempts, uploadErr)
		}
	}

	// Step 4: tell the server the upload is complete.
	_, err = r.completeMultipart(token, uuid, uploadID, fileName, fileSizeStr, modelUuid, dataType)
	return err
}