118 lines
2.7 KiB
Go
118 lines
2.7 KiB
Go
package rpc
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
|
|
"gitlink.org.cn/cloudream/common/pkgs/async"
|
|
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
)
|
|
|
|
// ServerEventChan is the event channel type returned by ServerBase.Start.
// It is an alias (not a distinct type) of async.UnboundChannel carrying
// RPCServerEvent values.
type ServerEventChan = async.UnboundChannel[RPCServerEvent]
|
|
|
|
// RPCServerEvent is the marker interface for values sent on a ServerEventChan.
// A type becomes an RPCServerEvent by having an IsRPCServerEvent method.
type RPCServerEvent interface {
	// IsRPCServerEvent exists only to tag implementing types.
	IsRPCServerEvent()
}
|
|
|
|
type ExitEvent struct {
|
|
RPCServerEvent
|
|
Err error
|
|
}
|
|
|
|
// Config is the JSON-configurable part of a ServerBase.
// All certificate settings are file paths that Start reads from disk.
type Config struct {
	Listen     string `json:"listen"`     // TCP address passed to net.Listen
	RootCA     string `json:"rootCA"`     // PEM file with the trusted root CA(s)
	ServerCert string `json:"serverCert"` // PEM server certificate file
	ServerKey  string `json:"serverKey"`  // PEM private key matching ServerCert
}
|
|
|
|
type ServerBase struct {
|
|
cfg Config
|
|
grpcSvr *grpc.Server
|
|
srvImpl any
|
|
svcDesc *grpc.ServiceDesc
|
|
rootCA *x509.CertPool
|
|
serverCert tls.Certificate
|
|
accessTokenAuthAPIs map[string]bool
|
|
tokenVerifier AccessTokenVerifier
|
|
noAuthAPIs map[string]bool
|
|
}
|
|
|
|
func NewServerBase(cfg Config, srvImpl any, svcDesc *grpc.ServiceDesc, accessTokenAuthAPIs []string, tokenVerifier AccessTokenVerifier, noAuthAPIs []string) *ServerBase {
|
|
tokenAuthAPIs := make(map[string]bool)
|
|
for _, api := range accessTokenAuthAPIs {
|
|
tokenAuthAPIs[api] = true
|
|
}
|
|
|
|
noAuths := make(map[string]bool)
|
|
for _, api := range noAuthAPIs {
|
|
noAuths[api] = true
|
|
}
|
|
|
|
return &ServerBase{
|
|
cfg: cfg,
|
|
srvImpl: srvImpl,
|
|
svcDesc: svcDesc,
|
|
accessTokenAuthAPIs: tokenAuthAPIs,
|
|
tokenVerifier: tokenVerifier,
|
|
noAuthAPIs: noAuths,
|
|
}
|
|
}
|
|
|
|
func (s *ServerBase) Start() *ServerEventChan {
|
|
ch := async.NewUnboundChannel[RPCServerEvent]()
|
|
go func() {
|
|
svrCert, err := tls.LoadX509KeyPair(s.cfg.ServerCert, s.cfg.ServerKey)
|
|
if err != nil {
|
|
logger.Warnf("load server cert: %v", err)
|
|
ch.Send(&ExitEvent{Err: err})
|
|
return
|
|
}
|
|
s.serverCert = svrCert
|
|
|
|
rootCA, err := os.ReadFile(s.cfg.RootCA)
|
|
if err != nil {
|
|
logger.Warnf("load root ca: %v", err)
|
|
ch.Send(&ExitEvent{Err: err})
|
|
return
|
|
}
|
|
|
|
s.rootCA = x509.NewCertPool()
|
|
if !s.rootCA.AppendCertsFromPEM(rootCA) {
|
|
logger.Warnf("load root ca: failed to parse root ca")
|
|
ch.Send(&ExitEvent{Err: fmt.Errorf("failed to parse root ca")})
|
|
return
|
|
}
|
|
|
|
logger.Infof("start serving rpc at: %v", s.cfg.Listen)
|
|
lis, err := net.Listen("tcp", s.cfg.Listen)
|
|
if err != nil {
|
|
ch.Send(&ExitEvent{Err: err})
|
|
return
|
|
}
|
|
|
|
s.grpcSvr = grpc.NewServer(
|
|
grpc.Creds(credentials.NewTLS(&tls.Config{
|
|
GetConfigForClient: s.tlsConfigSelector,
|
|
})),
|
|
grpc.UnaryInterceptor(s.authUnary),
|
|
grpc.StreamInterceptor(s.authStream),
|
|
)
|
|
s.grpcSvr.RegisterService(s.svcDesc, s.srvImpl)
|
|
err = s.grpcSvr.Serve(lis)
|
|
ch.Send(&ExitEvent{Err: err})
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
func (s *ServerBase) Stop() {
|
|
if s.grpcSvr != nil {
|
|
s.grpcSvr.Stop()
|
|
}
|
|
}
|