// Source: JCS-pub/common/pkgs/rpc/server.go (118 lines, 2.7 KiB, Go)

package rpc
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"gitlink.org.cn/cloudream/common/pkgs/async"
"gitlink.org.cn/cloudream/common/pkgs/logger"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// ServerEventChan is an alias for async.UnboundChannel[RPCServerEvent].
// ServerBase.Start returns a *ServerEventChan; every exit path of the server
// goroutine sends exactly one *ExitEvent on it.
// NOTE(review): "Unbound" suggests sends do not block on a full buffer — confirm
// against the async package.
type ServerEventChan = async.UnboundChannel[RPCServerEvent]
// RPCServerEvent is the marker interface for values delivered on a
// ServerEventChan. A type opts in by providing IsRPCServerEvent; ExitEvent does
// so by embedding this interface.
type RPCServerEvent interface {
IsRPCServerEvent()
}
// ExitEvent reports that the server goroutine started by ServerBase.Start has
// finished. It is sent as a *ExitEvent.
type ExitEvent struct {
// The embedded interface is left nil; it exists only so ExitEvent gets the
// IsRPCServerEvent method and therefore satisfies RPCServerEvent. Calling
// IsRPCServerEvent on an ExitEvent would panic on the nil embedded value.
RPCServerEvent
// Err is why the server stopped: a startup failure (certificate, root CA,
// listen) or the value returned by grpc.Server.Serve.
Err error
}
// Config is the JSON-decodable configuration of an RPC server.
type Config struct {
// Listen is the TCP address passed to net.Listen("tcp", ...), e.g. ":5010".
Listen string `json:"listen"`
// RootCA is the path to a PEM file of CA certificates, read with os.ReadFile
// and loaded into an x509.CertPool at startup.
RootCA string `json:"rootCA"`
// ServerCert and ServerKey are the paths of the server certificate and its
// private key, loaded together with tls.LoadX509KeyPair (PEM-encoded).
ServerCert string `json:"serverCert"`
ServerKey string `json:"serverKey"`
}
// ServerBase holds the configuration and runtime state of a TLS-secured gRPC
// server that serves a single registered service. Construct it with
// NewServerBase, run it with Start, and shut it down with Stop.
type ServerBase struct {
// cfg is the configuration passed to NewServerBase.
cfg Config
// grpcSvr is the gRPC server; nil until Start creates it (Stop checks for nil).
grpcSvr *grpc.Server
// srvImpl and svcDesc are passed to grpcSvr.RegisterService by Start.
srvImpl any
svcDesc *grpc.ServiceDesc
// rootCA is the CA pool Start builds from cfg.RootCA.
// NOTE(review): presumably consumed by tlsConfigSelector (defined elsewhere)
// to verify client certificates — confirm.
rootCA *x509.CertPool
// serverCert is the key pair Start loads from cfg.ServerCert/cfg.ServerKey.
serverCert tls.Certificate
// accessTokenAuthAPIs is the set of API names given to NewServerBase that are
// authenticated by access token; presumably checked by the auth interceptors
// defined elsewhere — confirm.
accessTokenAuthAPIs map[string]bool
// tokenVerifier verifies access tokens for accessTokenAuthAPIs (usage is
// outside this view).
tokenVerifier AccessTokenVerifier
// noAuthAPIs is the set of API names given to NewServerBase that presumably
// bypass authentication — confirm against the interceptors.
noAuthAPIs map[string]bool
}
// NewServerBase builds a ServerBase that is ready to Start but not yet running.
//
// cfg supplies the listen address and TLS file paths. srvImpl and svcDesc are
// the service implementation and descriptor that Start registers. The two API
// lists are converted into lookup sets (duplicates collapse to one key).
// NOTE(review): the entries are presumably full gRPC method names, as matched by
// the auth interceptors defined elsewhere — confirm. tokenVerifier is stored
// as-is.
func NewServerBase(cfg Config, srvImpl any, svcDesc *grpc.ServiceDesc, accessTokenAuthAPIs []string, tokenVerifier AccessTokenVerifier, noAuthAPIs []string) *ServerBase {
	// asSet turns a list of names into a membership set.
	asSet := func(names []string) map[string]bool {
		set := make(map[string]bool, len(names))
		for _, name := range names {
			set[name] = true
		}
		return set
	}

	return &ServerBase{
		cfg:                 cfg,
		srvImpl:             srvImpl,
		svcDesc:             svcDesc,
		accessTokenAuthAPIs: asSet(accessTokenAuthAPIs),
		tokenVerifier:       tokenVerifier,
		noAuthAPIs:          asSet(noAuthAPIs),
	}
}
// Start launches the RPC server and returns a channel on which exactly one
// *ExitEvent is delivered once the server stops running.
//
// The gRPC server is created and the service registered synchronously, before
// Start returns. A Stop call made after Start returns therefore always finds a
// server to stop. Assigning s.grpcSvr inside the goroutine instead would let a
// racing Stop observe nil, do nothing, and leave the server serving.
//
// The background goroutine then:
//   - loads the server key pair from Config.ServerCert / Config.ServerKey,
//   - reads and parses the PEM root CA bundle at Config.RootCA,
//   - listens on Config.Listen over TCP,
//   - serves until the listener fails or Stop is called.
//
// A startup failure is logged and reported as ExitEvent.Err, wrapped with
// context; use errors.Is / errors.As to inspect the cause. The error returned
// by grpc.Server.Serve is forwarded unchanged. Per grpc-go, Serve returns nil
// when Stop ends serving, and grpc.ErrServerStopped if Stop ran before Serve
// began.
func (s *ServerBase) Start() *ServerEventChan {
	ch := async.NewUnboundChannel[RPCServerEvent]()

	// The TLS material is not needed yet. tlsConfigSelector is only consulted
	// during handshakes, and those begin after Serve, which runs in the
	// goroutine below after serverCert and rootCA have been written.
	svr := grpc.NewServer(
		grpc.Creds(credentials.NewTLS(&tls.Config{
			GetConfigForClient: s.tlsConfigSelector,
		})),
		grpc.UnaryInterceptor(s.authUnary),
		grpc.StreamInterceptor(s.authStream),
	)
	svr.RegisterService(s.svcDesc, s.srvImpl)
	s.grpcSvr = svr

	go func() {
		svrCert, err := tls.LoadX509KeyPair(s.cfg.ServerCert, s.cfg.ServerKey)
		if err != nil {
			logger.Warnf("load server cert: %v", err)
			ch.Send(&ExitEvent{Err: fmt.Errorf("load server cert: %w", err)})
			return
		}
		s.serverCert = svrCert

		rootCAPEM, err := os.ReadFile(s.cfg.RootCA)
		if err != nil {
			logger.Warnf("load root ca: %v", err)
			ch.Send(&ExitEvent{Err: fmt.Errorf("load root ca: %w", err)})
			return
		}
		s.rootCA = x509.NewCertPool()
		// AppendCertsFromPEM reports false when no certificate could be parsed.
		if !s.rootCA.AppendCertsFromPEM(rootCAPEM) {
			logger.Warnf("load root ca: failed to parse root ca")
			ch.Send(&ExitEvent{Err: fmt.Errorf("load root ca %q: no PEM certificates could be parsed", s.cfg.RootCA)})
			return
		}

		lis, err := net.Listen("tcp", s.cfg.Listen)
		if err != nil {
			logger.Warnf("listen rpc at %v: %v", s.cfg.Listen, err)
			ch.Send(&ExitEvent{Err: fmt.Errorf("listen rpc at %v: %w", s.cfg.Listen, err)})
			return
		}
		logger.Infof("start serving rpc at: %v", s.cfg.Listen)

		// Serve blocks until the listener fails or Stop is called, and closes
		// lis before returning.
		ch.Send(&ExitEvent{Err: svr.Serve(lis)})
	}()

	return ch
}
// Stop shuts down the gRPC server immediately if one has been created. Per
// grpc-go, it closes the server's listeners and all open connections. It is a
// no-op when no server exists yet.
func (s *ServerBase) Stop() {
	svr := s.grpcSvr
	if svr == nil {
		return
	}
	svr.Stop()
}