JCS-pub/common/pkgs/rpc/utils.go

295 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package rpc
import (
"fmt"
"io"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/utils/io2"
"gitlink.org.cn/cloudream/common/utils/serder"
"gitlink.org.cn/cloudream/jcs-pub/common/ecode"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// UnaryClient performs a unary RPC whose request and response travel as JSON
// payloads wrapped in Request/Response messages.
//
// req is serialized with serder.ObjectToJSONEx and sent through apiFn; the
// returned payload is deserialized into Resp. RPC errors are converted with
// getCodeError, so a CodeError carried in the gRPC status details is returned
// as-is. Local (de)serialization failures are reported as
// errorcode.OperationFailed. On failure the returned Resp is the zero value.
func UnaryClient[Resp, Req any](apiFn func(context.Context, *Request, ...grpc.CallOption) (*Response, error), ctx context.Context, req Req) (Resp, *CodeError) {
	var zero Resp

	data, err := serder.ObjectToJSONEx(req)
	if err != nil {
		// Pass the message as an argument, not as the format string, so a '%'
		// inside it is not interpreted as a formatting verb.
		return zero, Failed(errorcode.OperationFailed, "%s", err.Error())
	}

	resp, err := apiFn(ctx, &Request{Payload: data})
	if err != nil {
		return zero, getCodeError(err)
	}

	ret, err := serder.JSONToObjectEx[Resp](resp.Payload)
	if err != nil {
		return zero, Failed(errorcode.OperationFailed, "%s", err.Error())
	}
	return ret, nil
}
// UnaryServer is the server-side counterpart of UnaryClient.
//
// It decodes the JSON payload of req into Req, invokes apiFn, and returns the
// JSON encoding of apiFn's result as the Response payload. A CodeError from
// apiFn is sent back through WrapCodeError; (de)serialization failures are
// reported as errorcode.OperationFailed via MakeCodeError.
func UnaryServer[Resp, Req any](apiFn func(context.Context, Req) (Resp, *CodeError), ctx context.Context, req *Request) (*Response, error) {
	in, err := serder.JSONToObjectEx[Req](req.Payload)
	if err != nil {
		return nil, MakeCodeError(errorcode.OperationFailed, err.Error())
	}

	out, codeErr := apiFn(ctx, in)
	if codeErr != nil {
		return nil, WrapCodeError(codeErr)
	}

	payload, err := serder.ObjectToJSONEx(out)
	if err != nil {
		return nil, MakeCodeError(errorcode.OperationFailed, err.Error())
	}

	resp := &Response{Payload: payload}
	return resp, nil
}
type (
	// UploadStreamAPIClient is the client side of a streaming-upload RPC.
	// Chunks are written through the embedded GRPCChunkedWriter, and
	// CloseAndRecv finishes the upload and returns the server's Response.
	UploadStreamAPIClient interface {
		GRPCChunkedWriter
		CloseAndRecv() (*Response, error)
	}

	// UploadStreamAPIServer is the server side of a streaming-upload RPC.
	// Chunks are read through the embedded GRPCChunkedReader, and the result
	// is returned with SendAndClose.
	UploadStreamAPIServer interface {
		GRPCChunkedReader
		Context() context.Context
		SendAndClose(*Response) error
	}

	// UploadStreamReq is implemented by upload request types that carry a data
	// stream alongside their JSON-serializable fields.
	UploadStreamReq interface {
		GetStream() io.Reader
		SetStream(io.Reader)
	}
)
// UploadStreamClient wraps the client-side logic of a streaming-upload API.
//
// The request is sent as two chunked parts: first the JSON encoding of req,
// then the contents of req.GetStream(). Because req itself is JSON-encoded,
// remember to exclude its Stream field from serialization with a `json:"-"`
// tag.
//
// The server's reply is deserialized into Resp. Errors from the RPC layer are
// converted with getCodeError, so a CodeError reported by the server (for
// example through WrapCodeError in UploadStreamServer) reaches the caller with
// its original code. Any other failure becomes errorcode.OperationFailed. On
// failure the returned Resp is the zero value.
func UploadStreamClient[Resp any, Req UploadStreamReq, APIRet UploadStreamAPIClient](apiFn func(context.Context, ...grpc.CallOption) (APIRet, error), ctx context.Context, req Req) (Resp, *CodeError) {
	var zero Resp

	stream := req.GetStream()
	data, err := serder.ObjectToJSONEx(req)
	if err != nil {
		// Message passed as an argument so '%' in it is not treated as a verb.
		return zero, Failed(errorcode.OperationFailed, "%s", err.Error())
	}

	// The deferred cancel tears the RPC down on every return path.
	ctx2, cancelFn := context.WithCancel(ctx)
	defer cancelFn()

	cli, err := apiFn(ctx2)
	if err != nil {
		return zero, getCodeError(err)
	}

	cw := NewChunkedWriter(cli)
	if err := cw.WriteDataPart("", data); err != nil {
		return zero, getCodeError(err)
	}
	if _, err := cw.WriteStreamPart("", stream); err != nil {
		return zero, getCodeError(err)
	}
	if err := cw.Finish(); err != nil {
		return zero, getCodeError(err)
	}

	// The server handler's final status, including any CodeError it attached
	// to the status details, is surfaced here. Using getCodeError (instead of
	// wrapping err.Error()) preserves that code for the caller.
	resp, err := cli.CloseAndRecv()
	if err != nil {
		return zero, getCodeError(err)
	}

	ret, err := serder.JSONToObjectEx[Resp](resp.Payload)
	if err != nil {
		return zero, Failed(errorcode.OperationFailed, "%s", err.Error())
	}
	return ret, nil
}
// UploadStreamServer is the server-side counterpart of UploadStreamClient.
//
// It reads the JSON data part and then the stream part from req, decodes the
// JSON into Req, attaches the stream with SetStream, and calls apiFn with the
// RPC's context. A CodeError from apiFn is sent back through WrapCodeError.
// Otherwise the JSON-encoded result is delivered with SendAndClose. Every
// other failure is reported as errorcode.OperationFailed.
func UploadStreamServer[Resp any, Req UploadStreamReq, APIRet UploadStreamAPIServer](apiFn func(context.Context, Req) (Resp, *CodeError), req APIRet) error {
	opFailed := func(e error) error {
		return MakeCodeError(errorcode.OperationFailed, e.Error())
	}

	reader := NewChunkedReader(req)

	_, reqJSON, err := reader.NextDataPart()
	if err != nil {
		return opFailed(err)
	}

	_, body, err := reader.NextPart()
	if err != nil {
		return opFailed(err)
	}

	in, err := serder.JSONToObjectEx[Req](reqJSON)
	if err != nil {
		return opFailed(err)
	}
	in.SetStream(body)

	out, codeErr := apiFn(req.Context(), in)
	if codeErr != nil {
		return WrapCodeError(codeErr)
	}

	outJSON, err := serder.ObjectToJSONEx(out)
	if err != nil {
		return opFailed(err)
	}

	if err := req.SendAndClose(&Response{Payload: outJSON}); err != nil {
		return opFailed(err)
	}
	return nil
}
type (
	// DownloadStreamAPIClient is the client side of a streaming-download RPC.
	// Chunks are read through the embedded GRPCChunkedReader.
	DownloadStreamAPIClient interface {
		GRPCChunkedReader
	}

	// DownloadStreamAPIServer is the server side of a streaming-download RPC.
	// Chunks are written through the embedded GRPCChunkedWriter.
	DownloadStreamAPIServer interface {
		GRPCChunkedWriter
		Context() context.Context
	}

	// DownloadStreamResp is implemented by download response types that carry
	// a data stream alongside their JSON-serializable fields.
	DownloadStreamResp interface {
		GetStream() io.ReadCloser
		SetStream(io.ReadCloser)
	}
)
// DownloadStreamClient wraps the client-side logic of a streaming-download API.
//
// req is JSON-encoded and sent as the request payload. The server replies with
// two chunked parts: the JSON encoding of Resp, then a data stream, which is
// attached to the decoded Resp with SetStream. Because Resp is JSON-decoded,
// remember to exclude its Stream field from serialization with a `json:"-"`
// tag.
//
// On success the RPC's context is cancelled only by the close callback given
// to io2.DelegateReadCloser, so the caller must close resp.GetStream() when
// done. On failure the context is cancelled here and the zero Resp is
// returned.
//
// RPC errors are converted with getCodeError. For this request/stream shape
// the server handler's error (e.g. a CodeError sent by DownloadStreamServer
// via WrapCodeError) only shows up when reading the stream, so the read
// errors must go through getCodeError too, or the original code is lost.
func DownloadStreamClient[Resp DownloadStreamResp, Req any, APIRet DownloadStreamAPIClient](apiFn func(context.Context, *Request, ...grpc.CallOption) (APIRet, error), ctx context.Context, req Req) (Resp, *CodeError) {
	var zero Resp

	data, err := serder.ObjectToJSONEx(req)
	if err != nil {
		// Message passed as an argument so '%' in it is not treated as a verb.
		return zero, Failed(errorcode.OperationFailed, "%s", err.Error())
	}

	ctx2, cancelFn := context.WithCancel(ctx)
	// Cancel on every failure path; on success responsibility for cancelFn is
	// handed to the returned stream's close callback.
	succeeded := false
	defer func() {
		if !succeeded {
			cancelFn()
		}
	}()

	cli, err := apiFn(ctx2, &Request{Payload: data})
	if err != nil {
		return zero, getCodeError(err)
	}

	cr := NewChunkedReader(cli)
	_, respJSON, err := cr.NextDataPart()
	if err != nil {
		return zero, getCodeError(err)
	}

	resp, err := serder.JSONToObjectEx[Resp](respJSON)
	if err != nil {
		return zero, Failed(errorcode.OperationFailed, "%s", err.Error())
	}

	_, pr, err := cr.NextPart()
	if err != nil {
		return zero, getCodeError(err)
	}

	resp.SetStream(io2.DelegateReadCloser(pr, func() error {
		cancelFn()
		return nil
	}))
	succeeded = true
	return resp, nil
}
// DownloadStreamServer is the server-side counterpart of DownloadStreamClient.
//
// It decodes req's JSON payload into Req and calls apiFn with the RPC's
// context. A CodeError from apiFn is sent back through WrapCodeError.
// Otherwise it writes the JSON encoding of the result as a data part, followed
// by the contents of resp.GetStream(), and finishes the chunked stream. Every
// other failure is reported as errorcode.OperationFailed.
//
// The stream returned by apiFn is closed here once the response has been
// written, or once writing has failed, so handlers do not need to close it.
func DownloadStreamServer[Resp DownloadStreamResp, Req any, APIRet DownloadStreamAPIServer](apiFn func(context.Context, Req) (Resp, *CodeError), req *Request, ret APIRet) error {
	rreq, err := serder.JSONToObjectEx[Req](req.Payload)
	if err != nil {
		return MakeCodeError(errorcode.OperationFailed, err.Error())
	}

	resp, cerr := apiFn(ret.Context(), rreq)
	if cerr != nil {
		return WrapCodeError(cerr)
	}

	// Nothing else in this wrapper closes the handler's stream, so release
	// whatever backs it on every return path below. The Close error is
	// ignored: the outcome of the RPC has already been decided by then.
	stream := resp.GetStream()
	if stream != nil {
		defer stream.Close()
	}

	cw := NewChunkedWriter(ret)
	data, err := serder.ObjectToJSONEx(resp)
	if err != nil {
		return MakeCodeError(errorcode.OperationFailed, err.Error())
	}
	if err := cw.WriteDataPart("", data); err != nil {
		return MakeCodeError(errorcode.OperationFailed, err.Error())
	}
	if _, err := cw.WriteStreamPart("", stream); err != nil {
		return MakeCodeError(errorcode.OperationFailed, err.Error())
	}
	if err := cw.Finish(); err != nil {
		return MakeCodeError(errorcode.OperationFailed, err.Error())
	}
	return nil
}
// Failed builds a *CodeError with the given error code. Its message is
// fmt.Sprintf(format, args...). To embed arbitrary text (such as an
// err.Error() string), pass it as an argument ("%s", msg) rather than as
// format.
func Failed(errCode ecode.ErrorCode, format string, args ...any) *CodeError {
	msg := fmt.Sprintf(format, args...)
	return &CodeError{
		Message: msg,
		Code:    string(errCode),
	}
}
// ErrorCodeError adapts a *CodeError to the error interface. It is a separate
// struct so that converting a nil *CodeError (see ToError) can yield a true nil
// error. Otherwise a non-nil interface holding a nil pointer would make
// (*CodeError)(nil) != nil.
type ErrorCodeError struct {
	// CE is the wrapped code error.
	CE *CodeError
}

// Error formats the wrapped CodeError as "code: <Code>, message: <Message>".
func (c *ErrorCodeError) Error() string {
	ce := c.CE
	return "code: " + ce.Code + ", message: " + ce.Message
}
// ToError converts c to a plain error. A nil *CodeError becomes a nil error
// (not a non-nil interface holding a nil pointer). Any other value is wrapped
// in an *ErrorCodeError.
func (c *CodeError) ToError() error {
	if c != nil {
		return &ErrorCodeError{CE: c}
	}
	return nil
}
// getCodeError converts an error returned by the gRPC layer into a *CodeError.
//
// If err carries a gRPC status whose details contain a *CodeError (as attached
// by MakeCodeError/WrapCodeError), the first such detail is returned
// unchanged. Otherwise err is reported as errorcode.OperationFailed, with
// err.Error() used verbatim as the message. A nil err yields nil.
func getCodeError(err error) *CodeError {
	if err == nil {
		return nil
	}

	// Named st rather than status to avoid shadowing the status package.
	if st, ok := status.FromError(err); ok {
		// Scan every detail instead of only the first, so an unrelated leading
		// detail does not hide the CodeError.
		for _, detail := range st.Details() {
			if ce, ok := detail.(*CodeError); ok {
				return ce
			}
		}
	}

	// Message passed as an argument so '%' in it is not treated as a verb.
	return Failed(errorcode.OperationFailed, "%s", err.Error())
}
// MakeCodeError builds a gRPC error (codes.Unknown) whose status details carry
// a CodeError with the given code and message, for getCodeError to recover.
//
// msg is used verbatim and is not treated as a format string. If the detail
// cannot be attached, a plain codes.Unknown status error that describes the
// code and message is returned instead. The result is never nil.
func MakeCodeError(code ecode.ErrorCode, msg string) error {
	ce := Failed(code, "%s", msg)
	st, err := status.New(codes.Unknown, "custom error").WithDetails(ce)
	if err != nil {
		// WithDetails returns a nil *Status on failure (e.g. when the detail
		// fails to marshal), and Err on a nil *Status returns nil. Ignoring
		// the error would therefore silently turn this failure into success.
		return status.Errorf(codes.Unknown, "%v", ce.ToError())
	}
	return st.Err()
}
// WrapCodeError builds a gRPC error (codes.Unknown) whose status details carry
// ce, for getCodeError to recover on the other side.
//
// If ce cannot be attached, a plain codes.Unknown status error that describes
// ce is returned instead. The result is never nil.
func WrapCodeError(ce *CodeError) error {
	st, err := status.New(codes.Unknown, "custom error").WithDetails(ce)
	if err != nil {
		// WithDetails returns a nil *Status on failure, and Err on a nil
		// *Status returns nil. Ignoring the error would therefore report
		// success to the gRPC runtime instead of an error.
		return status.Errorf(codes.Unknown, "%v", ce.ToError())
	}
	return st.Err()
}