177 lines
3.9 KiB
Go
177 lines
3.9 KiB
Go
package rpc
|
||
|
||
import (
|
||
context "context"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"fmt"
|
||
"sync"
|
||
"time"
|
||
|
||
grpc "google.golang.org/grpc"
|
||
"google.golang.org/grpc/credentials"
|
||
"google.golang.org/grpc/metadata"
|
||
)
|
||
|
||
// PoolConfig holds the TLS and authentication settings applied to every
// connection a ConnPool creates.
type PoolConfig struct {
	// RootCA is the certificate pool used to verify the server certificate
	// (passed through as tls.Config.RootCAs; per crypto/tls, nil means the
	// host's root CA set).
	RootCA *x509.CertPool

	// ClientCert is the client certificate for mutual-TLS authentication.
	// Mutually exclusive with AccessTokenProvider; if both are set,
	// ClientCert is the one used.
	ClientCert *tls.Certificate

	// AccessTokenProvider supplies per-call access-token credentials that are
	// attached as request metadata. Mutually exclusive with ClientCert.
	AccessTokenProvider AccessTokenProvider
}
|
||
|
||
// ConnPool shares gRPC client connections by target address and
// reference-counts them. Connections are created on first GetConnection and
// are closed after an idle grace period (one minute, see Release) once every
// reference has been released.
type ConnPool struct {
	// cfg holds the TLS/auth settings used for every new connection.
	cfg PoolConfig

	// grpcCons maps a target address to its pooled connection.
	// Guarded by lock.
	grpcCons map[string]*grpcCon

	// lock guards grpcCons and the fields of every entry stored in it.
	lock sync.Mutex
}
|
||
|
||
// grpcCon is one pooled connection plus its bookkeeping.
// All fields are read and written only while holding ConnPool.lock.
type grpcCon struct {
	// grpcCon is the underlying shared client connection.
	grpcCon *grpc.ClientConn

	// refCount is the number of GetConnection calls not yet matched by a
	// Release.
	refCount int

	// stopClosing is non-nil while a delayed close is pending (set by Release
	// when refCount drops to zero); GetConnection closes it to cancel that
	// close and resets it to nil.
	stopClosing chan any
}
|
||
|
||
func NewConnPool(cfg PoolConfig) *ConnPool {
|
||
return &ConnPool{
|
||
cfg: cfg,
|
||
grpcCons: make(map[string]*grpcCon),
|
||
}
|
||
}
|
||
|
||
func (p *ConnPool) GetConnection(addr string) (*grpc.ClientConn, error) {
|
||
p.lock.Lock()
|
||
defer p.lock.Unlock()
|
||
|
||
con := p.grpcCons[addr]
|
||
if con == nil {
|
||
gcon, err := p.connecting(addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
con = &grpcCon{
|
||
grpcCon: gcon,
|
||
refCount: 0,
|
||
stopClosing: nil,
|
||
}
|
||
|
||
p.grpcCons[addr] = con
|
||
} else if con.stopClosing != nil {
|
||
close(con.stopClosing)
|
||
con.stopClosing = nil
|
||
}
|
||
|
||
con.refCount++
|
||
|
||
return con.grpcCon, nil
|
||
}
|
||
|
||
func (p *ConnPool) connecting(addr string) (*grpc.ClientConn, error) {
|
||
if p.cfg.ClientCert != nil {
|
||
gcon, err := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||
RootCAs: p.cfg.RootCA,
|
||
Certificates: []tls.Certificate{*p.cfg.ClientCert},
|
||
ServerName: InternalAPISNIV1,
|
||
NextProtos: []string{"h2"},
|
||
})))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return gcon, nil
|
||
}
|
||
|
||
if p.cfg.AccessTokenProvider == nil {
|
||
return nil, fmt.Errorf("no client cert or access token provider")
|
||
}
|
||
|
||
gcon, err := grpc.NewClient(addr,
|
||
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||
RootCAs: p.cfg.RootCA,
|
||
ServerName: ClientAPISNIV1,
|
||
NextProtos: []string{"h2"},
|
||
})),
|
||
grpc.WithUnaryInterceptor(p.populateAccessTokenUnary),
|
||
grpc.WithStreamInterceptor(p.populateAccessTokenStream),
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return gcon, nil
|
||
}
|
||
func (p *ConnPool) Release(addr string) {
|
||
p.lock.Lock()
|
||
defer p.lock.Unlock()
|
||
|
||
grpcCon := p.grpcCons[addr]
|
||
if grpcCon == nil {
|
||
return
|
||
}
|
||
|
||
grpcCon.refCount--
|
||
grpcCon.refCount = max(grpcCon.refCount, 0)
|
||
|
||
if grpcCon.refCount == 0 {
|
||
stopClosing := make(chan any)
|
||
grpcCon.stopClosing = stopClosing
|
||
|
||
go func() {
|
||
select {
|
||
case <-stopClosing:
|
||
return
|
||
|
||
case <-time.After(time.Minute):
|
||
p.lock.Lock()
|
||
defer p.lock.Unlock()
|
||
|
||
grpcCon := p.grpcCons[addr]
|
||
if grpcCon == nil {
|
||
return
|
||
}
|
||
|
||
if grpcCon.refCount == 0 {
|
||
grpcCon.grpcCon.Close()
|
||
delete(p.grpcCons, addr)
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
}
|
||
|
||
func (p *ConnPool) populateAccessTokenUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||
authInfo, err := p.cfg.AccessTokenProvider.MakeAuthInfo()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
md := metadata.Pairs(
|
||
MetaUserID, fmt.Sprintf("%v", authInfo.UserID),
|
||
MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID),
|
||
MetaNonce, authInfo.Nonce,
|
||
MetaSignature, authInfo.Signature,
|
||
)
|
||
|
||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||
return invoker(ctx, method, req, reply, cc, opts...)
|
||
}
|
||
|
||
func (p *ConnPool) populateAccessTokenStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||
authInfo, err := p.cfg.AccessTokenProvider.MakeAuthInfo()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
md := metadata.Pairs(
|
||
MetaUserID, fmt.Sprintf("%v", authInfo.UserID),
|
||
MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID),
|
||
MetaNonce, authInfo.Nonce,
|
||
MetaSignature, authInfo.Signature,
|
||
)
|
||
|
||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||
return streamer(ctx, desc, cc, method, opts...)
|
||
}
|