252 lines
7.5 KiB
Go
252 lines
7.5 KiB
Go
package service
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-resty/resty/v2"
|
|
json "github.com/json-iterator/go"
|
|
"gitlink.org.cn/JointCloud/pcm-openi/common"
|
|
"gitlink.org.cn/JointCloud/pcm-openi/model"
|
|
"io"
|
|
"math"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"strconv"
|
|
)
|
|
|
|
// ModelService groups the client calls for creating, listing and uploading
// files to OpenI models. It carries no state; the zero value is ready to use.
type ModelService struct{}
|
|
|
|
func NewModelService() *ModelService {
|
|
return &ModelService{}
|
|
}
|
|
|
|
// CreateLocalModel creates a local model in the repository named by
// param.UserName / param.RepoName. It POSTs the model metadata and the access
// token as form fields to common.MODELLOCALCREATE.
//
// On success the decoded response body is returned. Any error reported by
// common.Request is returned as-is, together with whatever was decoded so far.
//
// NOTE(review): an error body is decoded into a model.RespErr but never
// inspected here. Confirm whether common.Request already turns non-2xx
// responses into errors; otherwise server-side failures go unnoticed.
func (r ModelService) CreateLocalModel(token string, param model.CreateLocalModelParam) (resp model.CreateLocalModel, err error) {
	errBody := &model.RespErr{}

	configure := func(req *resty.Request) {
		form := map[string]string{
			"name":             param.Name,
			"version":          param.Version,
			"engine":           strconv.Itoa(param.Engine),
			"label":            param.Label,
			"isPrivate":        strconv.FormatBool(param.IsPrivate),
			"description":      param.Description,
			"type":             strconv.Itoa(param.Type),
			"license":          param.License,
			common.ACCESSTOKEN: token,
		}
		req.SetPathParam("username", param.UserName)
		req.SetPathParam("reponame", param.RepoName)
		req.SetFormData(form)
		req.SetError(errBody)
		req.SetResult(&resp)
	}

	if _, err = common.Request(common.MODELLOCALCREATE, http.MethodPost, configure); err != nil {
		return resp, err
	}
	return resp, nil
}
|
|
|
|
// ListModel queries the models of the repository named by
// param.UserName / param.RepoName (paged query).
//
// Error response bodies are decoded into a model.RespErr and success bodies
// into the returned *model.ListModelResp. This is the same split that
// CreateLocalModel uses.
// It returns an error when param is nil or when common.Request fails.
//
// NOTE(review): no paging parameters are sent from here. Confirm whether
// QueryModelParam carries page/size fields that should be forwarded.
func (r ModelService) ListModel(token string, param *model.QueryModelParam) (resp *model.ListModelResp, err error) {
	if param == nil {
		return nil, errors.New("list model: nil query param")
	}

	respErr := &model.RespErr{}
	_, err = common.Request(common.PageModel, http.MethodGet, func(req *resty.Request) {
		req.SetPathParam("username", param.UserName).
			SetPathParam("reponame", param.RepoName).
			SetQueryParam(common.ACCESSTOKEN, token).
			// Route error bodies to respErr. Calling SetResult twice would make
			// the second call silently replace the first.
			SetError(respErr).
			SetResult(&resp)
	})
	if err != nil {
		return nil, err
	}
	return resp, nil
}
|
|
|
|
// getChunks 获取模型该文件已经上传的分片
|
|
func (r ModelService) getChunks(token, md5, fileName, dataType, modelUuid, size string) (resp *model.GetChunksResp, err error) {
|
|
_, err = common.Request(common.MODELLOCALGETUPLOADEDCHUNKS, http.MethodGet, func(req *resty.Request) {
|
|
req.SetQueryParams(map[string]string{
|
|
"md5": md5,
|
|
"file_name": fileName,
|
|
"type": dataType,
|
|
"modeluuid": modelUuid,
|
|
"size": size,
|
|
common.ACCESSTOKEN: token,
|
|
}).SetResult(&resp)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return
|
|
}
|
|
|
|
// newMultipart 开启一个本地模型文件上传
|
|
func (r ModelService) newMultipart(token, totalChunkCounts, md5, fileName, dataType, modelUuid, size string) (resp *model.NewMultipartResp, err error) {
|
|
res, err := common.Request(common.MODELLOCALNEWMULTIPART, http.MethodGet, func(req *resty.Request) {
|
|
req.SetQueryParams(map[string]string{
|
|
"totalChunkCounts": totalChunkCounts,
|
|
"type": dataType,
|
|
"size": size,
|
|
"md5": md5,
|
|
"file_name": fileName,
|
|
"modeluuid": modelUuid,
|
|
common.ACCESSTOKEN: token,
|
|
}).SetResult(&resp)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.UploadID == "" || resp.Uuid == "" {
|
|
msg := json.Get(res, "msg").ToString()
|
|
return nil, fmt.Errorf(msg)
|
|
}
|
|
return
|
|
}
|
|
|
|
// getMultipartUrl 获取模型分片传输url
|
|
func (r ModelService) getMultipartUrl(token, uuid, uploadID, fileName, dataType, modelUuid, size, chunkNumber string) (resp *model.GetMultipartUrlResp, err error) {
|
|
res, err := common.Request(common.MODELLOCALGETMULTIPARTURL, http.MethodGet, func(req *resty.Request) {
|
|
req.SetQueryParams(map[string]string{
|
|
"uuid": uuid,
|
|
"uploadID": uploadID,
|
|
"type": dataType,
|
|
"size": size,
|
|
"chunkNumber": chunkNumber,
|
|
"file_name": fileName,
|
|
"modeluuid": modelUuid,
|
|
common.ACCESSTOKEN: token,
|
|
}).SetResult(&resp)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.Url == "" {
|
|
msg := json.Get(res, "msg").ToString()
|
|
return nil, fmt.Errorf(msg)
|
|
}
|
|
return
|
|
}
|
|
|
|
// upLoadChunk PUTs one chunk of file data, read from reader, to reqUrl.
//
// token and fileName are not used by this function. They are kept so the
// signature stays compatible with existing callers.
// It returns nil only when the server answers 200 OK. For any other status,
// the error holds the status line plus, when present, the first 1 KiB of the
// response body, so server-side explanations are not lost.
//
// NOTE(review): no timeout is applied. Chunk uploads may legitimately be slow;
// consider a context deadline sized to the chunk.
func (r ModelService) upLoadChunk(token string, reqUrl, fileName string, reader io.Reader) (err error) {
	req, err := http.NewRequest(http.MethodPut, reqUrl, reader)
	if err != nil {
		return err
	}
	// Send an explicitly empty Content-Type header.
	// NOTE(review): presumably required so the request matches what the
	// upload URL expects — confirm before changing.
	req.Header.Set("Content-Type", "")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		// Read only a bounded prefix: enough to surface the server's reason.
		snippet, _ := io.ReadAll(io.LimitReader(res.Body, 1024))
		if len(snippet) == 0 {
			return errors.New(res.Status)
		}
		return fmt.Errorf("%s: %s", res.Status, snippet)
	}

	// Drain the body without buffering it so the transport can reuse the
	// connection.
	_, err = io.Copy(io.Discard, res.Body)
	return err
}
|
|
|
|
// completeMultipart 完成模型文件上传
|
|
func (r ModelService) completeMultipart(token string, uuid, uploadID, fileName, size, modelUuid, dataType string) (resp *model.CompleteMultipartResp, err error) {
|
|
_, err = common.Request(common.MODELLOCALCOMPLETEMULTIPART, http.MethodPost, func(req *resty.Request) {
|
|
req.SetFormData(map[string]string{
|
|
"uuid": uuid,
|
|
"uploadID": uploadID,
|
|
"type": dataType,
|
|
"modeluuid": modelUuid,
|
|
"file_name": fileName,
|
|
"size": size,
|
|
common.ACCESSTOKEN: token,
|
|
}).SetResult(&resp)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.ResultCode == "-1" {
|
|
return nil, fmt.Errorf(resp.Msg)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (r ModelService) UploadFile(param model.ModelUploadFileParam, token string, fileHeaders []*multipart.FileHeader) (respId string, err error) {
|
|
modelUuid := param.ModelUuid
|
|
for _, fileHeader := range fileHeaders {
|
|
// step.1 优先计算所需信息
|
|
dataType := "1"
|
|
uuid := ""
|
|
uploadID := ""
|
|
chunkNumber := 1
|
|
fileName := fileHeader.Filename
|
|
fileSize := fileHeader.Size
|
|
totalChunkCounts := int(math.Ceil(float64(fileSize) / float64(common.MaxChunkSize)))
|
|
|
|
fileSizeStr := strconv.FormatInt(fileSize, 10)
|
|
// 打开上传的文件
|
|
file, err := fileHeader.Open()
|
|
if err != nil {
|
|
return "", errors.New(fmt.Sprintf("文件打开失败: %s", err.Error()))
|
|
}
|
|
defer file.Close() // 确保关闭文件
|
|
md5hash, err := common.GetFileMd5(file)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// Get already uploaded chunks
|
|
chunks, err := r.getChunks(token, md5hash, fileName, dataType, modelUuid, fileSizeStr)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if chunks.Uploaded == "1" {
|
|
return "", errors.New(fmt.Sprintf("该文件已上传在模型: %s", chunks.ModelName))
|
|
}
|
|
if chunks.UploadID != "" && chunks.Uuid != "" {
|
|
uuid = chunks.Uuid
|
|
uploadID = chunks.UploadID
|
|
} else {
|
|
// Start a new multipart upload
|
|
newMultipart, err := r.newMultipart(token, strconv.Itoa(totalChunkCounts), md5hash, fileName, dataType, modelUuid, fileSizeStr)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
uuid = newMultipart.Uuid
|
|
uploadID = newMultipart.UploadID
|
|
}
|
|
|
|
// Upload each chunk
|
|
for chunkNumber <= totalChunkCounts {
|
|
// Get multipart URL for the current chunk
|
|
multipartUrl, err := r.getMultipartUrl(token, uuid, uploadID, fileName, dataType, modelUuid, fileSizeStr, strconv.Itoa(chunkNumber))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Create a reader for the current chunk
|
|
chunkReader := io.NewSectionReader(file, int64(chunkNumber-1)*common.MaxChunkSize, common.MaxChunkSize)
|
|
|
|
// Retry mechanism for uploading the current chunk
|
|
for attempt := 1; attempt <= 3; attempt++ {
|
|
err = r.upLoadChunk(token, multipartUrl.Url, fileName, chunkReader)
|
|
if err == nil {
|
|
break
|
|
}
|
|
if attempt == 3 {
|
|
return "", errors.New(fmt.Sprintf("error uploading chunk %d after 3 attempts: %s", chunkNumber, err.Error()))
|
|
}
|
|
}
|
|
chunkNumber++
|
|
}
|
|
|
|
// Complete the multipart upload
|
|
_, err = r.completeMultipart(token, uuid, uploadID, fileName, fileSizeStr, modelUuid, dataType)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
return "", nil
|
|
}
|