JCS-pub/common/pkgs/rpc/pool.go

177 lines
3.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package rpc
import (
context "context"
"crypto/tls"
"crypto/x509"
"fmt"
"sync"
"time"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
)
// PoolConfig holds the TLS and authentication settings a ConnPool uses when it
// creates new gRPC client connections.
type PoolConfig struct {
// RootCA is the set of root certificates used to verify the server's
// certificate (passed as tls.Config.RootCAs).
RootCA *x509.CertPool
// ClientCert is the client certificate presented for mutual TLS.
// Set either ClientCert or AccessTokenProvider. If both are set, ClientCert is
// used and connections are made with the InternalAPISNIV1 server name.
ClientCert *tls.Certificate
// AccessTokenProvider supplies the auth info (user ID, access token ID, nonce,
// signature) that is attached as metadata to every outgoing call.
// Set either AccessTokenProvider or ClientCert. It is only used when ClientCert
// is nil; connections are then made with the ClientAPISNIV1 server name.
AccessTokenProvider AccessTokenProvider
}
// ConnPool shares gRPC client connections by target address. Connections are
// reference-counted via GetConnection/Release and are closed some time after
// their last reference is released.
type ConnPool struct {
// cfg is the TLS/auth configuration used for every new connection.
cfg PoolConfig
// grpcCons maps a target address to its pooled connection entry.
grpcCons map[string]*grpcCon
// lock guards grpcCons and the fields of every entry stored in it.
lock sync.Mutex
}
// grpcCon is one pooled connection together with its reference bookkeeping.
type grpcCon struct {
// grpcCon is the underlying gRPC client connection.
grpcCon *grpc.ClientConn
// refCount is the number of GetConnection calls not yet balanced by Release.
refCount int
// stopClosing is non-nil while a delayed close scheduled by Release is
// pending; closing the channel cancels that pending close.
stopClosing chan any
}
// NewConnPool returns an empty pool that will create its connections using cfg.
func NewConnPool(cfg PoolConfig) *ConnPool {
	pool := new(ConnPool)
	pool.cfg = cfg
	pool.grpcCons = map[string]*grpcCon{}
	return pool
}
// GetConnection returns the pooled gRPC connection for addr, creating it on
// first use. Every successful call takes one reference, which the caller must
// give back with Release(addr). Reusing a connection whose delayed close is
// pending cancels that close.
func (p *ConnPool) GetConnection(addr string) (*grpc.ClientConn, error) {
	p.lock.Lock()
	defer p.lock.Unlock()

	if entry := p.grpcCons[addr]; entry != nil {
		// The entry may be idle with a close scheduled by Release; cancel it.
		if entry.stopClosing != nil {
			close(entry.stopClosing)
			entry.stopClosing = nil
		}
		entry.refCount++
		return entry.grpcCon, nil
	}

	cc, err := p.connecting(addr)
	if err != nil {
		return nil, err
	}
	p.grpcCons[addr] = &grpcCon{
		grpcCon:     cc,
		refCount:    1,
		stopClosing: nil,
	}
	return cc, nil
}
func (p *ConnPool) connecting(addr string) (*grpc.ClientConn, error) {
if p.cfg.ClientCert != nil {
gcon, err := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: p.cfg.RootCA,
Certificates: []tls.Certificate{*p.cfg.ClientCert},
ServerName: InternalAPISNIV1,
NextProtos: []string{"h2"},
})))
if err != nil {
return nil, err
}
return gcon, nil
}
if p.cfg.AccessTokenProvider == nil {
return nil, fmt.Errorf("no client cert or access token provider")
}
gcon, err := grpc.NewClient(addr,
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: p.cfg.RootCA,
ServerName: ClientAPISNIV1,
NextProtos: []string{"h2"},
})),
grpc.WithUnaryInterceptor(p.populateAccessTokenUnary),
grpc.WithStreamInterceptor(p.populateAccessTokenStream),
)
if err != nil {
return nil, err
}
return gcon, nil
}
// Release gives back one reference to the connection for addr obtained from
// GetConnection.
//
// When the last reference is released, the connection stays open for one
// minute. If GetConnection reuses it within that time, the pending close is
// cancelled. Otherwise it is closed and removed from the pool.
//
// Releasing an addr that is not pooled, or whose reference count is already
// zero, is a no-op.
func (p *ConnPool) Release(addr string) {
	p.lock.Lock()
	defer p.lock.Unlock()

	con := p.grpcCons[addr]
	if con == nil || con.refCount <= 0 {
		// Unknown addr, or an unbalanced Release. A zero-ref entry already has a
		// delayed close pending from the Release that zeroed it. Arming a second
		// one would orphan the first stopClosing channel, so GetConnection could
		// no longer cancel it.
		return
	}

	con.refCount--
	if con.refCount > 0 {
		return
	}

	stopClosing := make(chan any)
	con.stopClosing = stopClosing
	go func() {
		timer := time.NewTimer(time.Minute)
		defer timer.Stop()

		select {
		case <-stopClosing:
			// GetConnection reused the connection before the idle period ended.
			return
		case <-timer.C:
		}

		p.lock.Lock()
		defer p.lock.Unlock()

		// Close only if this timer is still the active one for this exact entry.
		// While we waited for the lock, GetConnection may have reused the entry
		// (clearing stopClosing). A later Release may also have armed a new timer,
		// or the entry may have been replaced.
		cur := p.grpcCons[addr]
		if cur != con || cur.stopClosing != stopClosing || cur.refCount != 0 {
			return
		}
		// Best-effort: nobody holds a reference, so there is no caller to report to.
		_ = con.grpcCon.Close()
		delete(p.grpcCons, addr)
	}()
}
// populateAccessTokenUnary is a unary client interceptor. It attaches
// access-token metadata from p.cfg.AccessTokenProvider to the call and then
// invokes it. If the auth info cannot be produced, the call is not sent and
// the provider's error is returned.
func (p *ConnPool) populateAccessTokenUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	ctx, err := p.attachAccessToken(ctx)
	if err != nil {
		return err
	}
	return invoker(ctx, method, req, reply, cc, opts...)
}

// populateAccessTokenStream is the streaming counterpart of
// populateAccessTokenUnary. It attaches access-token metadata before opening
// the stream, or returns the provider's error without opening it.
func (p *ConnPool) populateAccessTokenStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	ctx, err := p.attachAccessToken(ctx)
	if err != nil {
		return nil, err
	}
	return streamer(ctx, desc, cc, method, opts...)
}

// attachAccessToken returns a child of ctx whose outgoing metadata carries the
// auth info produced by p.cfg.AccessTokenProvider. Any outgoing metadata the
// caller already put on ctx is preserved; only the access-token keys are
// overwritten.
func (p *ConnPool) attachAccessToken(ctx context.Context) (context.Context, error) {
	authInfo, err := p.cfg.AccessTokenProvider.MakeAuthInfo()
	if err != nil {
		return nil, err
	}
	// NewOutgoingContext replaces, rather than merges, metadata already on ctx,
	// so start from a copy of the caller's metadata.
	existing, _ := metadata.FromOutgoingContext(ctx)
	md := existing.Copy()
	md.Set(MetaUserID, fmt.Sprintf("%v", authInfo.UserID))
	md.Set(MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID))
	md.Set(MetaNonce, authInfo.Nonce)
	md.Set(MetaSignature, authInfo.Signature)
	return metadata.NewOutgoingContext(ctx, md), nil
}