233 lines
6.2 KiB
Go
233 lines
6.2 KiB
Go
package rpc
|
||
|
||
import (
|
||
"crypto/tls"
|
||
"fmt"
|
||
"strconv"
|
||
|
||
jcstypes "gitlink.org.cn/cloudream/jcs-pub/common/types"
|
||
"golang.org/x/net/context"
|
||
"google.golang.org/grpc"
|
||
"google.golang.org/grpc/codes"
|
||
"google.golang.org/grpc/credentials"
|
||
"google.golang.org/grpc/metadata"
|
||
"google.golang.org/grpc/peer"
|
||
"google.golang.org/grpc/status"
|
||
)
|
||
|
||
// TLS server names (SNI) that select which API surface a connection reaches.
// tlsConfigSelector chooses the TLS configuration that matches each one.
const (
	// ClientAPISNIV1 is the server name for the client-facing API. A handshake
	// using it is not asked for a client certificate.
	ClientAPISNIV1 = "rpc.client.jcs-pub.v1"
	// InternalAPISNIV1 is the server name for the internal API. A handshake
	// using it must present a client certificate that verifies against the
	// server's root CA.
	InternalAPISNIV1 = "rpc.internal.jcs-pub.v1"
)

// gRPC metadata keys under which a caller sends its access-token credentials.
const (
	// MetaUserID carries the user ID, parsed by the server as a base-10 int64.
	MetaUserID = "x-jcs-user-id"
	// MetaAccessTokenID carries the access token ID.
	MetaAccessTokenID = "x-jcs-access-token-id"
	// MetaNonce carries the nonce.
	MetaNonce = "x-jcs-nonce"
	// MetaSignature carries the signature.
	MetaSignature = "x-jcs-signature"
)

// MetaTokenAuthInfo is used in this file as the context.Value key under which
// the verified AccessTokenAuthInfo is stored for handlers (see GetAuthInfo).
const MetaTokenAuthInfo = "x-jcs-token-auth-info"
|
||
|
||
type AccessTokenAuthInfo struct {
|
||
UserID jcstypes.UserID
|
||
AccessTokenID jcstypes.AccessTokenID
|
||
Nonce string
|
||
Signature string
|
||
}
|
||
|
||
type AccessTokenVerifier interface {
|
||
Verify(authInfo AccessTokenAuthInfo) bool
|
||
}
|
||
|
||
type AccessTokenProvider interface {
|
||
MakeAuthInfo() (AccessTokenAuthInfo, error)
|
||
}
|
||
|
||
func (s *ServerBase) tlsConfigSelector(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
switch hello.ServerName {
|
||
case ClientAPISNIV1:
|
||
return &tls.Config{
|
||
Certificates: []tls.Certificate{s.serverCert},
|
||
ClientAuth: tls.NoClientCert,
|
||
NextProtos: []string{"h2"},
|
||
}, nil
|
||
case InternalAPISNIV1:
|
||
return &tls.Config{
|
||
Certificates: []tls.Certificate{s.serverCert},
|
||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||
ClientCAs: s.rootCA,
|
||
NextProtos: []string{"h2"},
|
||
}, nil
|
||
|
||
default:
|
||
return nil, fmt.Errorf("unknown server name: %s", hello.ServerName)
|
||
}
|
||
}
|
||
|
||
func (s *ServerBase) authUnary(
|
||
ctx context.Context,
|
||
req interface{},
|
||
info *grpc.UnaryServerInfo,
|
||
handler grpc.UnaryHandler,
|
||
|
||
) (resp any, err error) {
|
||
pr, ok := peer.FromContext(ctx)
|
||
if !ok {
|
||
return nil, status.Error(codes.Unauthenticated, "no peer found in context")
|
||
}
|
||
|
||
tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo)
|
||
if !ok {
|
||
return nil, status.Error(codes.Unauthenticated, "no tls info found in peer")
|
||
}
|
||
|
||
// 如果是使用interanl ServerName通过的TLS认证,则直接放行
|
||
if tlsInfo.State.ServerName == InternalAPISNIV1 {
|
||
return handler(ctx, req)
|
||
}
|
||
|
||
// 如果是无需认证的API,则直接放行
|
||
if s.noAuthAPIs[info.FullMethod] {
|
||
return handler(ctx, req)
|
||
}
|
||
|
||
// 否则要进行额外的Token认证
|
||
|
||
if !s.accessTokenAuthAPIs[info.FullMethod] {
|
||
return nil, status.Error(codes.Unauthenticated, "unauthorized access")
|
||
}
|
||
|
||
meta, ok := metadata.FromIncomingContext(ctx)
|
||
if !ok {
|
||
return nil, status.Error(codes.Unauthenticated, "no metadata found in context")
|
||
}
|
||
|
||
userIDs := meta.Get(MetaUserID)
|
||
if len(userIDs) != 1 {
|
||
return nil, status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata")
|
||
}
|
||
userID, err := strconv.ParseInt(userIDs[0], 10, 64)
|
||
if err != nil {
|
||
return nil, status.Error(codes.Unauthenticated, "invalid user id in metadata")
|
||
}
|
||
|
||
accessTokenIDs := meta.Get(MetaAccessTokenID)
|
||
if len(accessTokenIDs) != 1 {
|
||
return nil, status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata")
|
||
}
|
||
|
||
nonce := meta.Get(MetaNonce)
|
||
if len(nonce) != 1 {
|
||
return nil, status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata")
|
||
}
|
||
|
||
signature := meta.Get(MetaSignature)
|
||
if len(signature) != 1 {
|
||
return nil, status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata")
|
||
}
|
||
|
||
authInfo := AccessTokenAuthInfo{
|
||
UserID: jcstypes.UserID(userID),
|
||
AccessTokenID: jcstypes.AccessTokenID(accessTokenIDs[0]),
|
||
Nonce: nonce[0],
|
||
Signature: signature[0],
|
||
}
|
||
if !s.tokenVerifier.Verify(authInfo) {
|
||
return nil, status.Error(codes.Unauthenticated, "invalid access token")
|
||
}
|
||
|
||
ctx = context.WithValue(ctx, MetaTokenAuthInfo, authInfo)
|
||
return handler(ctx, req)
|
||
}
|
||
|
||
func (s *ServerBase) authStream(
|
||
srv any,
|
||
stream grpc.ServerStream,
|
||
info *grpc.StreamServerInfo,
|
||
handler grpc.StreamHandler,
|
||
) error {
|
||
pr, ok := peer.FromContext(stream.Context())
|
||
if !ok {
|
||
return status.Error(codes.Unauthenticated, "no peer found in context")
|
||
}
|
||
|
||
tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo)
|
||
if !ok {
|
||
return status.Error(codes.Unauthenticated, "no tls info found in peer")
|
||
}
|
||
|
||
// 如果是使用interanl ServerName通过的TLS认证,则直接放行
|
||
if tlsInfo.State.ServerName == InternalAPISNIV1 {
|
||
return handler(srv, stream)
|
||
}
|
||
|
||
// 如果是无需认证的API,则直接放行
|
||
if s.noAuthAPIs[info.FullMethod] {
|
||
return handler(srv, stream)
|
||
}
|
||
|
||
// 否则要进行额外的Token认证
|
||
|
||
if !s.accessTokenAuthAPIs[info.FullMethod] {
|
||
return status.Error(codes.Unauthenticated, "unauthorized access")
|
||
}
|
||
|
||
meta, ok := metadata.FromIncomingContext(stream.Context())
|
||
if !ok {
|
||
return status.Error(codes.Unauthenticated, "no metadata found in context")
|
||
}
|
||
|
||
userIDs := meta.Get(MetaUserID)
|
||
if len(userIDs) != 1 {
|
||
return status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata")
|
||
}
|
||
userID, err := strconv.ParseInt(userIDs[0], 10, 64)
|
||
if err != nil {
|
||
return status.Error(codes.Unauthenticated, "invalid user id in metadata")
|
||
}
|
||
|
||
accessTokenIDs := meta.Get(MetaAccessTokenID)
|
||
if len(accessTokenIDs) != 1 {
|
||
return status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata")
|
||
}
|
||
|
||
nonce := meta.Get(MetaNonce)
|
||
if len(nonce) != 1 {
|
||
return status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata")
|
||
}
|
||
|
||
signature := meta.Get(MetaSignature)
|
||
if len(signature) != 1 {
|
||
return status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata")
|
||
}
|
||
|
||
authInfo := AccessTokenAuthInfo{
|
||
UserID: jcstypes.UserID(userID),
|
||
AccessTokenID: jcstypes.AccessTokenID(accessTokenIDs[0]),
|
||
Nonce: nonce[0],
|
||
Signature: signature[0],
|
||
}
|
||
|
||
if !s.tokenVerifier.Verify(authInfo) {
|
||
return status.Error(codes.Unauthenticated, "invalid access token")
|
||
}
|
||
|
||
return handler(srv, &serverStream{stream, context.WithValue(stream.Context(), MetaTokenAuthInfo, authInfo)})
|
||
}
|
||
|
||
type serverStream struct {
|
||
grpc.ServerStream
|
||
ctx context.Context
|
||
}
|
||
|
||
func (s *serverStream) Context() context.Context {
|
||
return s.ctx
|
||
}
|
||
|
||
func GetAuthInfo(ctx context.Context) (AccessTokenAuthInfo, bool) {
|
||
val := ctx.Value(MetaTokenAuthInfo)
|
||
if val == nil {
|
||
return AccessTokenAuthInfo{}, false
|
||
}
|
||
authInfo, ok := val.(AccessTokenAuthInfo)
|
||
return authInfo, ok
|
||
}
|