From 3d2499ebadf57e3499be9a1514fea1ac60f18a2a Mon Sep 17 00:00:00 2001 From: Sydonian <794346190@qq.com> Date: Fri, 30 May 2025 10:51:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9C=8D=E5=8A=A1=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/internal/accesstoken/accesstoken.go | 175 +++++++++++++ client/internal/accesstoken/config.go | 6 + client/internal/cmdline/serve.go | 57 ++++- client/internal/cmdline/test.go | 57 ++++- client/internal/cmdline/vfstest.go | 56 ++++- client/internal/config/config.go | 6 +- common/assets/confs/client.config.json | 15 +- common/assets/confs/coordinator.config.json | 10 +- common/assets/confs/datamap.config.json | 24 -- common/assets/confs/hub.config.json | 15 +- common/pkgs/accesstoken/accesstoken.go | 233 ++++++++++++++++++ common/pkgs/rpc/auth.go | 232 +++++++++++++++++ common/pkgs/rpc/coordinator/client.go | 10 +- common/pkgs/rpc/coordinator/coordinator.pb.go | 46 +++- common/pkgs/rpc/coordinator/coordinator.proto | 5 + .../rpc/coordinator/coordinator_grpc.pb.go | 148 +++++++++++ common/pkgs/rpc/coordinator/pool.go | 163 ++++++------ common/pkgs/rpc/coordinator/server.go | 19 +- common/pkgs/rpc/coordinator/storage.go | 2 + common/pkgs/rpc/coordinator/user.go | 95 +++++++ common/pkgs/rpc/hub/client.go | 4 +- common/pkgs/rpc/hub/hub.pb.go | 38 +-- common/pkgs/rpc/hub/hub.proto | 2 + common/pkgs/rpc/hub/hub_grpc.pb.go | 51 +++- common/pkgs/rpc/hub/ioswitch.go | 10 + common/pkgs/rpc/hub/mics.go | 2 + common/pkgs/rpc/hub/pool.go | 139 ++++------- common/pkgs/rpc/hub/server.go | 20 +- common/pkgs/rpc/hub/user.go | 29 +++ common/pkgs/rpc/pool.go | 176 +++++++++++++ common/pkgs/rpc/rpc.go | 1 + common/pkgs/rpc/server.go | 75 +++++- common/pkgs/rpc/utils.go | 34 +-- .../internal/accesstoken/accesstoken.go | 38 +++ coordinator/internal/cmd/cert.go | 212 ++++++++++++++++ coordinator/internal/cmd/migrate.go | 2 + coordinator/internal/cmd/serve.go | 37 ++- coordinator/internal/config/config.go | 10 +- .../internal/db/loaded_access_token.go | 49 ++++ coordinator/internal/db/user.go | 10 +- coordinator/internal/db/user_access_token.go | 33 +++ coordinator/internal/repl/user.go | 62 +++++ coordinator/internal/rpc/service.go | 9 +- coordinator/internal/rpc/user.go | 227 +++++++++++++++++ .../ticktock/clear_expired_access_token.go | 112 +++++++++ coordinator/internal/ticktock/ticktock.go | 23 +- coordinator/types/types.go | 46 +++- go.mod | 10 +- go.sum | 20 +- hub/internal/accesstoken/accesstoken.go | 49 ++++ hub/internal/cmd/serve.go | 52 +++- hub/internal/config/config.go | 4 +- hub/internal/rpc/rpc.go | 13 +- hub/internal/rpc/user.go | 19 ++ 54 files changed, 2661 insertions(+), 331 deletions(-) create mode 100644 client/internal/accesstoken/accesstoken.go create mode 100644 client/internal/accesstoken/config.go delete mode 100644 common/assets/confs/datamap.config.json create mode 100644 common/pkgs/accesstoken/accesstoken.go create mode 100644 common/pkgs/rpc/auth.go create mode 100644 common/pkgs/rpc/coordinator/user.go create mode 100644 common/pkgs/rpc/hub/user.go create mode 100644 common/pkgs/rpc/pool.go create mode 100644 common/pkgs/rpc/rpc.go create mode 100644 coordinator/internal/accesstoken/accesstoken.go create mode 100644 coordinator/internal/cmd/cert.go create mode 100644 coordinator/internal/db/loaded_access_token.go create mode 100644 coordinator/internal/db/user_access_token.go create mode 100644 coordinator/internal/repl/user.go create mode 100644 coordinator/internal/rpc/user.go create mode 100644 coordinator/internal/ticktock/clear_expired_access_token.go create mode 100644 hub/internal/accesstoken/accesstoken.go create mode 100644 hub/internal/rpc/user.go diff --git a/client/internal/accesstoken/accesstoken.go b/client/internal/accesstoken/accesstoken.go new file mode 100644 index 0000000..832e1af --- /dev/null +++ b/client/internal/accesstoken/accesstoken.go @@ -0,0 +1,175 @@ +package accesstoken + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "fmt" + "sync" + "time" + + "gitlink.org.cn/cloudream/common/pkgs/async" + "gitlink.org.cn/cloudream/common/pkgs/logger" + stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type KeeperEvent interface { + IsAccessTokenKeeper() bool +} + +type ExitEvent struct { + KeeperEvent + Err error +} + +type Keeper struct { + cfg Config + enabled bool + token cortypes.UserAccessToken + priKey ed25519.PrivateKey + lock sync.RWMutex + done chan any +} + +func New(cfg Config, tempCli *corrpc.TempClient) (*Keeper, error) { + loginResp, cerr := tempCli.UserLogin(context.Background(), &corrpc.UserLogin{ + Account: cfg.Account, + Password: cfg.Password, + }) + if cerr != nil { + return nil, fmt.Errorf("login: %w", cerr.ToError()) + } + + priKey, err := hex.DecodeString(loginResp.PrivateKey) + if err != nil { + return nil, fmt.Errorf("decode private key: %w", err) + } + + return &Keeper{ + cfg: cfg, + enabled: true, + token: loginResp.Token, + priKey: priKey, + done: make(chan any, 1), + }, nil +} + +func NewDisabled() *Keeper { + return &Keeper{ + done: make(chan any, 1), + enabled: false, + } +} + +func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] { + log := logger.WithField("Mod", "Keeper") + + ch := async.NewUnboundChannel[KeeperEvent]() + + go func() { + if !k.enabled { + return + } + + k.lock.RLock() + log.Infof("login success, token expires at %v", k.token.ExpiresAt) + k.lock.RUnlock() + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + loop: + for { + select { + case <-k.done: + break loop + + case <-ticker.C: + k.lock.RLock() + token := k.token + k.lock.RUnlock() + + // 当前Token已经过期,说明之前的刷新都失败了,打个日志 + if time.Now().After(token.ExpiresAt) { + log.Warnf("token expired at %v !", token.ExpiresAt) + } + + // 在Token到期前5分钟时就要开始刷新Token + + tokenDeadline := token.ExpiresAt.Add(-time.Minute * 5) + if time.Now().Before(tokenDeadline) { + continue + } + + corCli := stgglb.CoordinatorRPCPool.Get() + refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{}) + if cerr != nil { + log.Warnf("refresh token: %v", cerr) + corCli.Release() + continue + } + + priKey, err := hex.DecodeString(refResp.PrivateKey) + if err != nil { + log.Warnf("decode private key: %v", err) + corCli.Release() + continue + } + + log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt) + + k.lock.Lock() + k.token = refResp.Token + k.priKey = priKey + k.lock.Unlock() + + corCli.Release() + } + } + ch.Send(ExitEvent{}) + }() + + return ch +} + +func (k *Keeper) Stop() { + select { + case k.done <- true: + default: + } +} + +func (k *Keeper) GetAuthInfo() (rpc.AccessTokenAuthInfo, error) { + if !k.enabled { + return rpc.AccessTokenAuthInfo{}, fmt.Errorf("function disabled") + } + + k.lock.RLock() + token := k.token + k.lock.RUnlock() + + bytes := make([]byte, 8) + + _, err := rand.Read(bytes) + if err != nil { + return rpc.AccessTokenAuthInfo{}, fmt.Errorf("generate nonce: %w", err) + } + + nonce := hex.EncodeToString(bytes) + stringToSign := accesstoken.MakeStringToSign(token.UserID, token.TokenID, nonce) + + signBytes := ed25519.Sign(k.priKey, []byte(stringToSign)) + signature := hex.EncodeToString(signBytes) + + return rpc.AccessTokenAuthInfo{ + UserID: token.UserID, + AccessTokenID: token.TokenID, + Nonce: nonce, + Signature: signature, + }, nil +} diff --git a/client/internal/accesstoken/config.go b/client/internal/accesstoken/config.go new file mode 100644 index 0000000..985e7f3 --- /dev/null +++ b/client/internal/accesstoken/config.go @@ -0,0 +1,6 @@ +package accesstoken + +type Config struct { + Account string `json:"account"` + Password string `json:"password"` +} diff --git a/client/internal/cmdline/serve.go b/client/internal/cmdline/serve.go index d433025..6e738a4 100644 --- a/client/internal/cmdline/serve.go +++ b/client/internal/cmdline/serve.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "gitlink.org.cn/cloudream/common/pkgs/logger" "gitlink.org.cn/cloudream/jcs-pub/client/internal/accessstat" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" @@ -69,8 +70,42 @@ func serveHTTP(configPath string, opts serveHTTPOptions) { } stgglb.InitLocal(config.Cfg().Local) - stgglb.InitPools(&config.Cfg().HubRPC, &config.Cfg().CoordinatorRPC) - stgglb.StandaloneMode = opts.Standalone + + stgglb.StandaloneMode = opts.Standalone || config.Cfg().AccessToken == nil + + var accToken *accesstoken.Keeper + if !stgglb.StandaloneMode { + tempCli, err := config.Cfg().CoordinatorRPC.BuildTempClient() + if err != nil { + logger.Warnf("build coordinator rpc temp client: %v", err) + os.Exit(1) + } + + accToken, err = accesstoken.New(*config.Cfg().AccessToken, tempCli) + tempCli.Release() + if err != nil { + logger.Warnf("new access token keeper: %v", err) + os.Exit(1) + } + + hubRPCCfg, err := config.Cfg().HubRPC.Build(accToken) + if err != nil { + logger.Warnf("build hub rpc pool config: %v", err) + os.Exit(1) + } + + corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(accToken) + if err != nil { + logger.Warnf("build coordinator rpc pool config: %v", err) + os.Exit(1) + } + + stgglb.InitPools(hubRPCCfg, corRPCCfg) + } else { + accToken = accesstoken.NewDisabled() + } + accTokenChan := accToken.Start() + defer accToken.Stop() // 数据库 db, err := db.NewDB(&config.Cfg().DB) @@ -162,6 +197,7 @@ func serveHTTP(configPath string, opts serveHTTPOptions) { /// 开始监听各个模块的事件 + accTokenEvt := accTokenChan.Receive() evtPubEvt := evtPubChan.Receive() acStatEvt := acStatChan.Receive() replEvt := replCh.Receive() @@ -171,6 +207,23 @@ func serveHTTP(configPath string, opts serveHTTPOptions) { loop: for { select { + case e := <-accTokenEvt.Chan(): + if e.Err != nil { + logger.Errorf("receive access token event: %v", err) + break loop + } + + switch e := e.Value.(type) { + case accesstoken.ExitEvent: + if e.Err != nil { + logger.Errorf("access token keeper exit with error: %v", err) + } else { + logger.Info("access token keeper exited") + } + break loop + } + accTokenEvt = accTokenChan.Receive() + case e := <-evtPubEvt.Chan(): if e.Err != nil { logger.Errorf("receive publisher event: %v", err) diff --git a/client/internal/cmdline/test.go b/client/internal/cmdline/test.go index 2eb95e1..5bd87fd 100644 --- a/client/internal/cmdline/test.go +++ b/client/internal/cmdline/test.go @@ -10,6 +10,7 @@ import ( "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" "gitlink.org.cn/cloudream/common/pkgs/logger" "gitlink.org.cn/cloudream/jcs-pub/client/internal/accessstat" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" @@ -80,7 +81,42 @@ func test(configPath string) { } stgglb.InitLocal(config.Cfg().Local) - stgglb.InitPools(&config.Cfg().HubRPC, &config.Cfg().CoordinatorRPC) + + stgglb.StandaloneMode = config.Cfg().AccessToken == nil + + var accToken *accesstoken.Keeper + if !stgglb.StandaloneMode { + tempCli, err := config.Cfg().CoordinatorRPC.BuildTempClient() + if err != nil { + logger.Warnf("build coordinator rpc temp client: %v", err) + os.Exit(1) + } + + accToken, err = accesstoken.New(*config.Cfg().AccessToken, tempCli) + tempCli.Release() + if err != nil { + logger.Warnf("new access token keeper: %v", err) + os.Exit(1) + } + + hubRPCCfg, err := config.Cfg().HubRPC.Build(accToken) + if err != nil { + logger.Warnf("build hub rpc pool config: %v", err) + os.Exit(1) + } + + corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(accToken) + if err != nil { + logger.Warnf("build coordinator rpc pool config: %v", err) + os.Exit(1) + } + + stgglb.InitPools(hubRPCCfg, corRPCCfg) + } else { + accToken = accesstoken.NewDisabled() + } + accTokenChan := accToken.Start() + defer accToken.Stop() // 数据库 db, err := db.NewDB(&config.Cfg().DB) @@ -140,13 +176,30 @@ func test(configPath string) { os.Exit(0) }() /// 开始监听各个模块的事件 - + accTokenEvt := accTokenChan.Receive() evtPubEvt := evtPubChan.Receive() acStatEvt := acStatChan.Receive() loop: for { select { + case e := <-accTokenEvt.Chan(): + if e.Err != nil { + logger.Errorf("receive access token event: %v", err) + break loop + } + + switch e := e.Value.(type) { + case accesstoken.ExitEvent: + if e.Err != nil { + logger.Errorf("access token keeper exit with error: %v", err) + } else { + logger.Info("access token keeper exited") + } + break loop + } + accTokenEvt = accTokenChan.Receive() + case e := <-evtPubEvt.Chan(): if e.Err != nil { logger.Errorf("receive publisher event: %v", err) diff --git a/client/internal/cmdline/vfstest.go b/client/internal/cmdline/vfstest.go index 89e173e..21ec6c8 100644 --- a/client/internal/cmdline/vfstest.go +++ b/client/internal/cmdline/vfstest.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "gitlink.org.cn/cloudream/common/pkgs/logger" "gitlink.org.cn/cloudream/jcs-pub/client/internal/accessstat" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" @@ -60,7 +61,42 @@ func vfsTest(configPath string, opts serveHTTPOptions) { } stgglb.InitLocal(config.Cfg().Local) - stgglb.InitPools(&config.Cfg().HubRPC, &config.Cfg().CoordinatorRPC) + + stgglb.StandaloneMode = opts.Standalone || config.Cfg().AccessToken == nil + + var accToken *accesstoken.Keeper + if !opts.Standalone { + tempCli, err := config.Cfg().CoordinatorRPC.BuildTempClient() + if err != nil { + logger.Warnf("build coordinator rpc temp client: %v", err) + os.Exit(1) + } + + accToken, err = accesstoken.New(*config.Cfg().AccessToken, tempCli) + tempCli.Release() + if err != nil { + logger.Warnf("new access token keeper: %v", err) + os.Exit(1) + } + + hubRPCCfg, err := config.Cfg().HubRPC.Build(accToken) + if err != nil { + logger.Warnf("build hub rpc pool config: %v", err) + os.Exit(1) + } + + corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(accToken) + if err != nil { + logger.Warnf("build coordinator rpc pool config: %v", err) + os.Exit(1) + } + + stgglb.InitPools(hubRPCCfg, corRPCCfg) + } else { + accToken = accesstoken.NewDisabled() + } + accTokenChan := accToken.Start() + defer accToken.Stop() // 数据库 db, err := db.NewDB(&config.Cfg().DB) @@ -152,6 +188,7 @@ func vfsTest(configPath string, opts serveHTTPOptions) { /// 开始监听各个模块的事件 + accTokenEvt := accTokenChan.Receive() evtPubEvt := evtPubChan.Receive() acStatEvt := acStatChan.Receive() httpEvt := httpChan.Receive() @@ -160,6 +197,23 @@ func vfsTest(configPath string, opts serveHTTPOptions) { loop: for { select { + case e := <-accTokenEvt.Chan(): + if e.Err != nil { + logger.Errorf("receive access token event: %v", err) + break loop + } + + switch e := e.Value.(type) { + case accesstoken.ExitEvent: + if e.Err != nil { + logger.Errorf("access token keeper exit with error: %v", err) + } else { + logger.Info("access token keeper exited") + } + break loop + } + accTokenEvt = accTokenChan.Receive() + case e := <-evtPubEvt.Chan(): if e.Err != nil { logger.Errorf("receive publisher event: %v", err) diff --git a/client/internal/config/config.go b/client/internal/config/config.go index 312ba85..7fdd693 100644 --- a/client/internal/config/config.go +++ b/client/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "gitlink.org.cn/cloudream/common/pkgs/logger" "gitlink.org.cn/cloudream/common/utils/config" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader/strategy" @@ -18,8 +19,8 @@ import ( type Config struct { Local stgglb.LocalMachineInfo `json:"local"` - HubRPC hubrpc.PoolConfig `json:"hubRPC"` - CoordinatorRPC corrpc.PoolConfig `json:"coordinatorRPC"` + HubRPC hubrpc.PoolConfigJSON `json:"hubRPC"` + CoordinatorRPC corrpc.PoolConfigJSON `json:"coordinatorRPC"` Logger logger.Config `json:"logger"` DB db.Config `json:"db"` SysEvent sysevent.Config `json:"sysEvent"` @@ -29,6 +30,7 @@ type Config struct { TickTock ticktock.Config `json:"tickTock"` HTTP *http.Config `json:"http"` Mount *mntcfg.Config `json:"mount"` + AccessToken *accesstoken.Config `json:"accessToken"` } var cfg Config diff --git a/common/assets/confs/client.config.json b/common/assets/confs/client.config.json index dff48db..90e57e6 100644 --- a/common/assets/confs/client.config.json +++ b/common/assets/confs/client.config.json @@ -5,9 +5,16 @@ "externalIP": "127.0.0.1", "locationID": 1 }, - "hubRPC": {}, + "hubRPC": { + "rootCA": "", + "clientCert": "", + "clientKey": "" + }, "coordinatorRPC": { - "address": "127.0.0.1:5009" + "address": "127.0.0.1:5009", + "rootCA": "", + "clientCert": "", + "clientKey": "" }, "logger": { "output": "stdout", @@ -62,5 +69,9 @@ "cacheActiveTime": "1m", "cacheExpireTime": "1m", "scanDataDirInterval": "10m" + }, + "accessToken": { + "account": "", + "password": "" } } \ No newline at end of file diff --git a/common/assets/confs/coordinator.config.json b/common/assets/confs/coordinator.config.json index f906810..2668020 100644 --- a/common/assets/confs/coordinator.config.json +++ b/common/assets/confs/coordinator.config.json @@ -15,6 +15,14 @@ "hubUnavailableTime": "20s" }, "rpc": { - "listen": "127.0.0.1:5009" + "listen": "127.0.0.1:5009", + "rootCA": "", + "serverCert": "", + "serverKey": "" + }, + "hubRPC": { + "rootCA": "", + "clientCert": "", + "clientKey": "" } } \ No newline at end of file diff --git a/common/assets/confs/datamap.config.json b/common/assets/confs/datamap.config.json deleted file mode 100644 index 5d5c206..0000000 --- a/common/assets/confs/datamap.config.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "logger": { - "output": "file", - "outputFileName": "datamap", - "outputDirectory": "log", - "level": "debug" - }, - "db": { - "address": "106.75.6.194:3306", - "account": "root", - "password": "cloudream123456", - "databaseName": "cloudream" - }, - "rabbitMQ": { - "address": "106.75.6.194:5672", - "account": "cloudream", - "password": "123456", - "vhost": "/", - "param": { - "retryNum": 5, - "retryInterval": 5000 - } - } -} \ No newline at end of file diff --git a/common/assets/confs/hub.config.json b/common/assets/confs/hub.config.json index 53422b3..2b03c33 100644 --- a/common/assets/confs/hub.config.json +++ b/common/assets/confs/hub.config.json @@ -6,13 +6,24 @@ "locationID": 1 }, "rpc": { - "listen": "127.0.0.1:5010" + "listen": "127.0.0.1:5010", + "rootCA": "", + "serverCert": "", + "serverKey": "" }, "http": { "listen": "127.0.0.1:5110" }, "coordinatorRPC": { - "address": "127.0.0.1:5009" + "address": "127.0.0.1:5009", + "rootCA": "", + "clientCert": "", + "clientKey": "" + }, + "hubRPC": { + "rootCA": "", + "clientCert": "", + "clientKey": "" }, "logger": { "output": "file", diff --git a/common/pkgs/accesstoken/accesstoken.go b/common/pkgs/accesstoken/accesstoken.go new file mode 100644 index 0000000..314b7c8 --- /dev/null +++ b/common/pkgs/accesstoken/accesstoken.go @@ -0,0 +1,233 @@ +package accesstoken + +import ( + "crypto/ed25519" + "encoding/hex" + "fmt" + "sync" + "time" + + "gitlink.org.cn/cloudream/common/pkgs/async" + "gitlink.org.cn/cloudream/common/pkgs/logger" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type CacheEvent interface { + IsAccessTokenCacheEvent() bool +} + +type ExitEvent struct { + CacheEvent + Err error +} + +type CacheKey struct { + UserID cortypes.UserID + TokenID cortypes.AccessTokenID +} + +var ErrTokenNotFound = fmt.Errorf("token not found") + +type AccessTokenLoader func(key CacheKey) (cortypes.UserAccessToken, error) + +type CacheEntry struct { + IsTokenValid bool + Token cortypes.UserAccessToken + PublicKey ed25519.PublicKey + LoadedAt time.Time + LastUsedAt time.Time +} + +type Cache struct { + lock sync.Mutex + cache map[CacheKey]*CacheEntry + done chan any + loader AccessTokenLoader +} + +func New(loader AccessTokenLoader) *Cache { + return &Cache{ + cache: make(map[CacheKey]*CacheEntry), + done: make(chan any, 1), + loader: loader, + } +} + +func (nc *Cache) Start() *async.UnboundChannel[CacheEvent] { + log := logger.WithField("Mod", "AccessTokenCache") + + ch := async.NewUnboundChannel[CacheEvent]() + go func() { + ticker := time.NewTicker(time.Second * 10) + defer ticker.Stop() + + loop: + for { + select { + case <-nc.done: + break loop + + case <-ticker.C: + nc.lock.Lock() + for key, entry := range nc.cache { + if !entry.IsTokenValid { + // 无效Token的记录5分钟后删除 + if time.Since(entry.LoadedAt) > time.Minute*5 { + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("delete expired invalid token") + delete(nc.cache, key) + continue + } + } else { + // 5分钟没有使用的Token则删除 + if time.Since(entry.LastUsedAt) > time.Minute*5 { + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("delete unused token") + delete(nc.cache, key) + continue + } + + // 过期Token标记为无效 + if time.Now().After(entry.Token.ExpiresAt) { + entry.IsTokenValid = false + entry.LastUsedAt = time.Now() + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("token expired") + + } else if time.Since(entry.LoadedAt) > time.Minute*5 { + // 依然有效的Token,则5分钟检查一次有效性 + go nc.load(key) + } + } + } + nc.lock.Unlock() + } + } + + ch.Send(&ExitEvent{}) + }() + return ch +} + +func (mc *Cache) Stop() { + select { + case mc.done <- true: + default: + } +} + +func (mc *Cache) Get(key CacheKey) (*CacheEntry, bool) { + var ret *CacheEntry + var ok bool + + for i := 0; i < 2; i++ { + mc.lock.Lock() + entry, getOk := mc.cache[key] + + if getOk { + ret = entry + ok = true + ret.LastUsedAt = time.Now() + + // 如果Token已经过期,则直接设置为无效Token。因为Token是随机生成的,几乎不可能把一个过期的Token再用上 + if entry.IsTokenValid && time.Now().After(entry.Token.ExpiresAt) { + entry.IsTokenValid = false + entry.LastUsedAt = time.Now() + } + } + + mc.lock.Unlock() + + if ok { + break + } + + mc.load(key) + } + + return ret, ok +} + +func (mc *Cache) NotifyTokenInvalid(key CacheKey) { + log := logger.WithField("Mod", "AccessTokenCache") + + mc.lock.Lock() + defer mc.lock.Unlock() + + entry, ok := mc.cache[key] + if !ok { + entry = &CacheEntry{ + IsTokenValid: false, + LoadedAt: time.Now(), + LastUsedAt: time.Now(), + } + mc.cache[key] = entry + return + } + + entry.IsTokenValid = false + entry.LastUsedAt = time.Now() + + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("notify token invalid") +} + +func (mc *Cache) load(key CacheKey) { + log := logger.WithField("Mod", "AccessTokenCache") + + loadToken, cerr := mc.loader(key) + + mc.lock.Lock() + defer mc.lock.Unlock() + + if cerr != nil { + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Warnf("load token: %v", cerr) + + // 明确是无效的Token的也缓存一下,用于快速拒绝请求 + if cerr == ErrTokenNotFound { + mc.cache[key] = &CacheEntry{ + IsTokenValid: false, + LoadedAt: time.Now(), + LastUsedAt: time.Now(), + } + } + return + } + + pubKey, err := hex.DecodeString(loadToken.PublicKey) + if err != nil { + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Warnf("invalid public key: %v", err) + return + } + + mc.cache[key] = &CacheEntry{ + IsTokenValid: true, + Token: loadToken, + PublicKey: pubKey, + LoadedAt: time.Now(), + LastUsedAt: time.Now(), + } + + log.WithField("UserID", key.UserID).WithField("TokenID", key.TokenID).Infof("load token success, expires at: %v", loadToken.ExpiresAt) +} + +func (mc *Cache) Verify(authInfo rpc.AccessTokenAuthInfo) bool { + token, ok := mc.Get(CacheKey{ + UserID: authInfo.UserID, + TokenID: authInfo.AccessTokenID, + }) + if !ok { + return false + } + if !token.IsTokenValid { + return false + } + + sig, err := hex.DecodeString(authInfo.Signature) + if err != nil { + return false + } + + return ed25519.Verify(token.PublicKey, []byte(MakeStringToSign(authInfo.UserID, authInfo.AccessTokenID, authInfo.Nonce)), []byte(sig)) +} + +func MakeStringToSign(userID cortypes.UserID, tokenID cortypes.AccessTokenID, nonce string) string { + return fmt.Sprintf("%v.%v.%v", userID, tokenID, nonce) +} diff --git a/common/pkgs/rpc/auth.go b/common/pkgs/rpc/auth.go new file mode 100644 index 0000000..4f38a0a --- /dev/null +++ b/common/pkgs/rpc/auth.go @@ -0,0 +1,232 @@ +package rpc + +import ( + "crypto/tls" + "fmt" + "strconv" + + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const ( + ClientAPISNIV1 = "rpc.client.jcs-pub.v1" + InternalAPISNIV1 = "rpc.internal.jcs-pub.v1" + + MetaUserID = "x-jcs-user-id" + MetaAccessTokenID = "x-jcs-access-token-id" + MetaNonce = "x-jcs-nonce" + MetaSignature = "x-jcs-signature" + MetaTokenAuthInfo = "x-jcs-token-auth-info" +) + +type AccessTokenAuthInfo struct { + UserID cortypes.UserID + AccessTokenID cortypes.AccessTokenID + Nonce string + Signature string +} + +type AccessTokenVerifier interface { + Verify(authInfo AccessTokenAuthInfo) bool +} + +type AccessTokenProvider interface { + GetAuthInfo() (AccessTokenAuthInfo, error) +} + +func (s *ServerBase) tlsConfigSelector(hello *tls.ClientHelloInfo) (*tls.Config, error) { + switch hello.ServerName { + case ClientAPISNIV1: + return &tls.Config{ + Certificates: []tls.Certificate{s.serverCert}, + ClientAuth: tls.NoClientCert, + NextProtos: []string{"h2"}, + }, nil + case InternalAPISNIV1: + return &tls.Config{ + Certificates: []tls.Certificate{s.serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: s.rootCA, + NextProtos: []string{"h2"}, + }, nil + + default: + return nil, fmt.Errorf("unknown server name: %s", hello.ServerName) + } +} + +func (s *ServerBase) authUnary( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + +) (resp any, err error) { + pr, ok := peer.FromContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "no peer found in context") + } + + tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo) + if !ok { + return nil, status.Error(codes.Unauthenticated, "no tls info found in peer") + } + + // 如果是使用interanl ServerName通过的TLS认证,则直接放行 + if tlsInfo.State.ServerName == InternalAPISNIV1 { + return handler(ctx, req) + } + + // 如果是无需认证的API,则直接放行 + if s.noAuthAPIs[info.FullMethod] { + return handler(ctx, req) + } + + // 否则要进行额外的Token认证 + + if !s.accessTokenAuthAPIs[info.FullMethod] { + return nil, status.Error(codes.Unauthenticated, "unauthorized access") + } + + meta, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "no metadata found in context") + } + + userIDs := meta.Get(MetaUserID) + if len(userIDs) != 1 { + return nil, status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata") + } + userID, err := strconv.ParseInt(userIDs[0], 10, 64) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "invalid user id in metadata") + } + + accessTokenIDs := meta.Get(MetaAccessTokenID) + if len(accessTokenIDs) != 1 { + return nil, status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata") + } + + nonce := meta.Get(MetaNonce) + if len(nonce) != 1 { + return nil, status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata") + } + + signature := meta.Get(MetaSignature) + if len(signature) != 1 { + return nil, status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata") + } + + authInfo := AccessTokenAuthInfo{ + UserID: cortypes.UserID(userID), + AccessTokenID: cortypes.AccessTokenID(accessTokenIDs[0]), + Nonce: nonce[0], + Signature: signature[0], + } + if !s.tokenVerifier.Verify(authInfo) { + return nil, status.Error(codes.Unauthenticated, "invalid access token") + } + + ctx = context.WithValue(ctx, MetaTokenAuthInfo, authInfo) + return handler(ctx, req) +} + +func (s *ServerBase) authStream( + srv any, + stream grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + pr, ok := peer.FromContext(stream.Context()) + if !ok { + return status.Error(codes.Unauthenticated, "no peer found in context") + } + + tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo) + if !ok { + return status.Error(codes.Unauthenticated, "no tls info found in peer") + } + + // 如果是使用interanl ServerName通过的TLS认证,则直接放行 + if tlsInfo.State.ServerName == InternalAPISNIV1 { + return handler(srv, stream) + } + + // 如果是无需认证的API,则直接放行 + if s.noAuthAPIs[info.FullMethod] { + return handler(srv, stream) + } + + // 否则要进行额外的Token认证 + + if !s.accessTokenAuthAPIs[info.FullMethod] { + return status.Error(codes.Unauthenticated, "unauthorized access") + } + + meta, ok := metadata.FromIncomingContext(stream.Context()) + if !ok { + return status.Error(codes.Unauthenticated, "no metadata found in context") + } + + userIDs := meta.Get(MetaUserID) + if len(userIDs) != 1 { + return status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata") + } + userID, err := strconv.ParseInt(userIDs[0], 10, 64) + if err != nil { + return status.Error(codes.Unauthenticated, "invalid user id in metadata") + } + + accessTokenIDs := meta.Get(MetaAccessTokenID) + if len(accessTokenIDs) != 1 { + return status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata") + } + + nonce := meta.Get(MetaNonce) + if len(nonce) != 1 { + return status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata") + } + + signature := meta.Get(MetaSignature) + if len(signature) != 1 { + return status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata") + } + + authInfo := AccessTokenAuthInfo{ + UserID: cortypes.UserID(userID), + AccessTokenID: cortypes.AccessTokenID(accessTokenIDs[0]), + Nonce: nonce[0], + Signature: signature[0], + } + + if !s.tokenVerifier.Verify(authInfo) { + return status.Error(codes.Unauthenticated, "invalid access token") + } + + return handler(srv, &serverStream{stream, context.WithValue(stream.Context(), MetaTokenAuthInfo, authInfo)}) +} + +type serverStream struct { + grpc.ServerStream + ctx context.Context +} + +func (s *serverStream) Context() context.Context { + return s.ctx +} + +func GetAuthInfo(ctx context.Context) (AccessTokenAuthInfo, bool) { + val := ctx.Value(MetaTokenAuthInfo) + if val == nil { + return AccessTokenAuthInfo{}, false + } + authInfo, ok := val.(AccessTokenAuthInfo) + return authInfo, ok +} diff --git a/common/pkgs/rpc/coordinator/client.go b/common/pkgs/rpc/coordinator/client.go index ebdfd0b..93bc5ec 100644 --- a/common/pkgs/rpc/coordinator/client.go +++ b/common/pkgs/rpc/coordinator/client.go @@ -14,9 +14,17 @@ type Client struct { func (c *Client) Release() { if c.con != nil { - c.pool.release() + c.pool.connPool.Release(c.pool.cfg.Address) } } +type TempClient struct { + Client +} + +func (c *TempClient) Release() { + c.con.Close() +} + // 客户端的API要和服务端的API保持一致 var _ CoordinatorAPI = (*Client)(nil) diff --git a/common/pkgs/rpc/coordinator/coordinator.pb.go b/common/pkgs/rpc/coordinator/coordinator.pb.go index 183b676..a94901f 100644 --- a/common/pkgs/rpc/coordinator/coordinator.pb.go +++ b/common/pkgs/rpc/coordinator/coordinator.pb.go @@ -27,7 +27,7 @@ var file_pkgs_rpc_coordinator_coordinator_proto_rawDesc = []byte{ 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2f, 0x63, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, 0x1a, 0x12, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x72, 0x70, 0x63, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xfe, 0x01, 0x0a, 0x0b, 0x43, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, + 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xb7, 0x03, 0x0a, 0x0b, 0x43, 0x6f, 0x6f, 0x72, 0x64, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x2b, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x48, 0x75, 0x62, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, @@ -43,11 +43,23 @@ var file_pkgs_rpc_coordinator_coordinator_proto_rawDesc = []byte{ 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2f, 0x0a, 0x10, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x53, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x48, 0x75, 0x62, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, - 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, 0x6e, 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, - 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, 0x70, 0x75, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, - 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, - 0x3b, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x28, 0x0a, 0x09, 0x55, 0x73, 0x65, 0x72, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x2f, 0x0a, 0x10, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x29, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x72, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x0c, + 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, + 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x31, 0x0a, 0x12, 0x48, + 0x75, 0x62, 0x4c, 0x6f, 0x61, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, + 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, 0x6e, + 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, 0x70, + 0x75, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, + 0x70, 0x63, 0x2f, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, 0x3b, 0x63, 0x6f, 0x72, 0x72, 0x70, 0x63, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var file_pkgs_rpc_coordinator_coordinator_proto_goTypes = []any{ @@ -60,13 +72,21 @@ var file_pkgs_rpc_coordinator_coordinator_proto_depIdxs = []int32{ 0, // 2: corrpc.Coordinator.GetHubConnectivities:input_type -> rpc.Request 0, // 3: corrpc.Coordinator.ReportHubConnectivity:input_type -> rpc.Request 0, // 4: corrpc.Coordinator.SelectStorageHub:input_type -> rpc.Request - 1, // 5: corrpc.Coordinator.GetHubConfig:output_type -> rpc.Response - 1, // 6: corrpc.Coordinator.GetHubs:output_type -> rpc.Response - 1, // 7: corrpc.Coordinator.GetHubConnectivities:output_type -> rpc.Response - 1, // 8: corrpc.Coordinator.ReportHubConnectivity:output_type -> rpc.Response - 1, // 9: corrpc.Coordinator.SelectStorageHub:output_type -> rpc.Response - 5, // [5:10] is the sub-list for method output_type - 0, // [0:5] is the sub-list for method input_type + 0, // 5: corrpc.Coordinator.UserLogin:input_type -> rpc.Request + 0, // 6: corrpc.Coordinator.UserRefreshToken:input_type -> rpc.Request + 0, // 7: corrpc.Coordinator.UserLogout:input_type -> rpc.Request + 0, // 8: corrpc.Coordinator.HubLoadAccessToken:input_type -> rpc.Request + 1, // 9: corrpc.Coordinator.GetHubConfig:output_type -> rpc.Response + 1, // 10: corrpc.Coordinator.GetHubs:output_type -> rpc.Response + 1, // 11: corrpc.Coordinator.GetHubConnectivities:output_type -> rpc.Response + 1, // 12: corrpc.Coordinator.ReportHubConnectivity:output_type -> rpc.Response + 1, // 13: corrpc.Coordinator.SelectStorageHub:output_type -> rpc.Response + 1, // 14: corrpc.Coordinator.UserLogin:output_type -> rpc.Response + 1, // 15: corrpc.Coordinator.UserRefreshToken:output_type -> rpc.Response + 1, // 16: corrpc.Coordinator.UserLogout:output_type -> rpc.Response + 1, // 17: corrpc.Coordinator.HubLoadAccessToken:output_type -> rpc.Response + 9, // [9:18] is the sub-list for method output_type + 0, // [0:9] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/common/pkgs/rpc/coordinator/coordinator.proto b/common/pkgs/rpc/coordinator/coordinator.proto index c4e565a..40b8039 100644 --- a/common/pkgs/rpc/coordinator/coordinator.proto +++ b/common/pkgs/rpc/coordinator/coordinator.proto @@ -14,4 +14,9 @@ service Coordinator { rpc ReportHubConnectivity(rpc.Request) returns(rpc.Response); rpc SelectStorageHub(rpc.Request) returns(rpc.Response); + + rpc UserLogin(rpc.Request) returns(rpc.Response); + rpc UserRefreshToken(rpc.Request) returns(rpc.Response); + rpc UserLogout(rpc.Request) returns(rpc.Response); + rpc HubLoadAccessToken(rpc.Request) returns(rpc.Response); } \ No newline at end of file diff --git a/common/pkgs/rpc/coordinator/coordinator_grpc.pb.go b/common/pkgs/rpc/coordinator/coordinator_grpc.pb.go index 93a7cf2..173038b 100644 --- a/common/pkgs/rpc/coordinator/coordinator_grpc.pb.go +++ b/common/pkgs/rpc/coordinator/coordinator_grpc.pb.go @@ -25,6 +25,10 @@ const ( Coordinator_GetHubConnectivities_FullMethodName = "/corrpc.Coordinator/GetHubConnectivities" Coordinator_ReportHubConnectivity_FullMethodName = "/corrpc.Coordinator/ReportHubConnectivity" Coordinator_SelectStorageHub_FullMethodName = "/corrpc.Coordinator/SelectStorageHub" + Coordinator_UserLogin_FullMethodName = "/corrpc.Coordinator/UserLogin" + Coordinator_UserRefreshToken_FullMethodName = "/corrpc.Coordinator/UserRefreshToken" + Coordinator_UserLogout_FullMethodName = "/corrpc.Coordinator/UserLogout" + Coordinator_HubLoadAccessToken_FullMethodName = "/corrpc.Coordinator/HubLoadAccessToken" ) // CoordinatorClient is the client API for Coordinator service. @@ -36,6 +40,10 @@ type CoordinatorClient interface { GetHubConnectivities(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) ReportHubConnectivity(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) SelectStorageHub(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) + UserLogin(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) + UserRefreshToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) + UserLogout(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) + HubLoadAccessToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) } type coordinatorClient struct { @@ -91,6 +99,42 @@ func (c *coordinatorClient) SelectStorageHub(ctx context.Context, in *rpc.Reques return out, nil } +func (c *coordinatorClient) UserLogin(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { + out := new(rpc.Response) + err := c.cc.Invoke(ctx, Coordinator_UserLogin_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *coordinatorClient) UserRefreshToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { + out := new(rpc.Response) + err := c.cc.Invoke(ctx, Coordinator_UserRefreshToken_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *coordinatorClient) UserLogout(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { + out := new(rpc.Response) + err := c.cc.Invoke(ctx, Coordinator_UserLogout_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *coordinatorClient) HubLoadAccessToken(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { + out := new(rpc.Response) + err := c.cc.Invoke(ctx, Coordinator_HubLoadAccessToken_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // CoordinatorServer is the server API for Coordinator service. // All implementations must embed UnimplementedCoordinatorServer // for forward compatibility @@ -100,6 +144,10 @@ type CoordinatorServer interface { GetHubConnectivities(context.Context, *rpc.Request) (*rpc.Response, error) ReportHubConnectivity(context.Context, *rpc.Request) (*rpc.Response, error) SelectStorageHub(context.Context, *rpc.Request) (*rpc.Response, error) + UserLogin(context.Context, *rpc.Request) (*rpc.Response, error) + UserRefreshToken(context.Context, *rpc.Request) (*rpc.Response, error) + UserLogout(context.Context, *rpc.Request) (*rpc.Response, error) + HubLoadAccessToken(context.Context, *rpc.Request) (*rpc.Response, error) mustEmbedUnimplementedCoordinatorServer() } @@ -122,6 +170,18 @@ func (UnimplementedCoordinatorServer) ReportHubConnectivity(context.Context, *rp func (UnimplementedCoordinatorServer) SelectStorageHub(context.Context, *rpc.Request) (*rpc.Response, error) { return nil, status.Errorf(codes.Unimplemented, "method SelectStorageHub not implemented") } +func (UnimplementedCoordinatorServer) UserLogin(context.Context, *rpc.Request) (*rpc.Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method UserLogin not implemented") +} +func (UnimplementedCoordinatorServer) UserRefreshToken(context.Context, *rpc.Request) (*rpc.Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method UserRefreshToken not implemented") +} +func (UnimplementedCoordinatorServer) UserLogout(context.Context, *rpc.Request) (*rpc.Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method UserLogout not implemented") +} +func (UnimplementedCoordinatorServer) HubLoadAccessToken(context.Context, *rpc.Request) (*rpc.Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method HubLoadAccessToken not implemented") +} func (UnimplementedCoordinatorServer) mustEmbedUnimplementedCoordinatorServer() {} // UnsafeCoordinatorServer may be embedded to opt out of forward compatibility for this service. @@ -225,6 +285,78 @@ func _Coordinator_SelectStorageHub_Handler(srv interface{}, ctx context.Context, return interceptor(ctx, in, info, handler) } +func _Coordinator_UserLogin_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(rpc.Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CoordinatorServer).UserLogin(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Coordinator_UserLogin_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CoordinatorServer).UserLogin(ctx, req.(*rpc.Request)) + } + return interceptor(ctx, in, info, handler) +} + +func _Coordinator_UserRefreshToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(rpc.Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CoordinatorServer).UserRefreshToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Coordinator_UserRefreshToken_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CoordinatorServer).UserRefreshToken(ctx, req.(*rpc.Request)) + } + return interceptor(ctx, in, info, handler) +} + +func _Coordinator_UserLogout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(rpc.Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CoordinatorServer).UserLogout(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Coordinator_UserLogout_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CoordinatorServer).UserLogout(ctx, req.(*rpc.Request)) + } + return interceptor(ctx, in, info, handler) +} + +func _Coordinator_HubLoadAccessToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(rpc.Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CoordinatorServer).HubLoadAccessToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Coordinator_HubLoadAccessToken_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CoordinatorServer).HubLoadAccessToken(ctx, req.(*rpc.Request)) + } + return interceptor(ctx, in, info, handler) +} + // Coordinator_ServiceDesc is the grpc.ServiceDesc for Coordinator service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -252,6 +384,22 @@ var Coordinator_ServiceDesc = grpc.ServiceDesc{ MethodName: "SelectStorageHub", Handler: _Coordinator_SelectStorageHub_Handler, }, + { + MethodName: "UserLogin", + Handler: _Coordinator_UserLogin_Handler, + }, + { + MethodName: "UserRefreshToken", + Handler: _Coordinator_UserRefreshToken_Handler, + }, + { + MethodName: "UserLogout", + Handler: _Coordinator_UserLogout_Handler, + }, + { + MethodName: "HubLoadAccessToken", + Handler: _Coordinator_HubLoadAccessToken_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "pkgs/rpc/coordinator/coordinator.proto", diff --git a/common/pkgs/rpc/coordinator/pool.go b/common/pkgs/rpc/coordinator/pool.go index 829c077..8c3b59a 100644 --- a/common/pkgs/rpc/coordinator/pool.go +++ b/common/pkgs/rpc/coordinator/pool.go @@ -1,100 +1,115 @@ package corrpc import ( - "sync" - "time" + "crypto/tls" + "crypto/x509" + "fmt" + "os" "gitlink.org.cn/cloudream/common/consts/errorcode" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" - grpc "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) type PoolConfig struct { - Address string `json:"address"` + Address string + Conn rpc.PoolConfig +} + +type PoolConfigJSON struct { + Address string `json:"address"` + RootCA string `json:"rootCA"` + ClientCert string `json:"clientCert"` + ClientKey string `json:"clientKey"` +} + +func (c *PoolConfigJSON) Build(tokenProv rpc.AccessTokenProvider) (*PoolConfig, error) { + pc := &PoolConfig{ + Address: c.Address, + } + pc.Conn.AccessTokenProvider = tokenProv + + rootCA, err := os.ReadFile(c.RootCA) + if err != nil { + return nil, fmt.Errorf("load root ca: %v", err) + } + pc.Conn.RootCA = x509.NewCertPool() + if !pc.Conn.RootCA.AppendCertsFromPEM(rootCA) { + return nil, fmt.Errorf("failed to parse root ca") + } + + if c.ClientCert != "" && c.ClientKey != "" { + cert, err := tls.LoadX509KeyPair(c.ClientCert, c.ClientKey) + if err != nil { + return nil, fmt.Errorf("load client cert: %v", err) + } + pc.Conn.ClientCert = &cert + } else if tokenProv == nil { + return nil, fmt.Errorf("must provide client cert or access token provider") + } + + return pc, nil +} + +func (c *PoolConfigJSON) BuildTempClient() (*TempClient, error) { + rootCA, err := os.ReadFile(c.RootCA) + if err != nil { + return nil, fmt.Errorf("load root ca: %v", err) + } + rootCAs := x509.NewCertPool() + if !rootCAs.AppendCertsFromPEM(rootCA) { + return nil, fmt.Errorf("failed to parse root ca") + } + + gcon, err := grpc.NewClient(c.Address, + grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: rootCAs, + ServerName: rpc.ClientAPISNIV1, + NextProtos: []string{"h2"}, + })), + ) + if err != nil { + return nil, err + } + + return &TempClient{ + Client: Client{ + con: gcon, + cli: NewCoordinatorClient(gcon), + pool: nil, + fusedErr: nil, + }, + }, nil } type Pool struct { - cfg PoolConfig - grpcCon *grpcCon - lock sync.Mutex -} - -type grpcCon struct { - grpcCon *grpc.ClientConn - refCount int - stopClosing chan any + cfg PoolConfig + connPool *rpc.ConnPool } func NewPool(cfg PoolConfig) *Pool { return &Pool{ - cfg: cfg, + cfg: cfg, + connPool: rpc.NewConnPool(cfg.Conn), } } func (p *Pool) Get() *Client { - p.lock.Lock() - defer p.lock.Unlock() - - con := p.grpcCon - if con == nil { - gcon, err := grpc.NewClient(p.cfg.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - return &Client{ - con: nil, - pool: p, - fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), - } + con, err := p.connPool.GetConnection(p.cfg.Address) + if err != nil { + return &Client{ + con: nil, + cli: nil, + pool: p, + fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), } - - con = &grpcCon{ - grpcCon: gcon, - refCount: 0, - stopClosing: nil, - } - - p.grpcCon = con - - } else if con.stopClosing != nil { - close(con.stopClosing) - con.stopClosing = nil } - con.refCount++ - return &Client{ - con: con.grpcCon, - cli: NewCoordinatorClient(con.grpcCon), - pool: p, - } -} - -func (p *Pool) release() { - p.lock.Lock() - defer p.lock.Unlock() - - grpcCon := p.grpcCon - grpcCon.refCount-- - grpcCon.refCount = max(grpcCon.refCount, 0) - - if grpcCon.refCount == 0 { - stopClosing := make(chan any) - grpcCon.stopClosing = stopClosing - - go func() { - select { - case <-stopClosing: - return - - case <-time.After(time.Minute): - p.lock.Lock() - defer p.lock.Unlock() - - if p.grpcCon.refCount == 0 { - p.grpcCon.grpcCon.Close() - p.grpcCon = nil - } - } - }() + con: con, + cli: NewCoordinatorClient(con), + pool: p, + fusedErr: nil, } } diff --git a/common/pkgs/rpc/coordinator/server.go b/common/pkgs/rpc/coordinator/server.go index d4276ad..3a659e3 100644 --- a/common/pkgs/rpc/coordinator/server.go +++ b/common/pkgs/rpc/coordinator/server.go @@ -7,6 +7,7 @@ import ( type CoordinatorAPI interface { HubService StorageService + UserService } type Server struct { @@ -15,12 +16,26 @@ type Server struct { svrImpl CoordinatorAPI } -func NewServer(cfg rpc.Config, impl CoordinatorAPI) *Server { +func NewServer(cfg rpc.Config, impl CoordinatorAPI, tokenVerifier rpc.AccessTokenVerifier) *Server { svr := &Server{ svrImpl: impl, } - svr.ServerBase = rpc.NewServerBase(cfg, svr, &Coordinator_ServiceDesc) + svr.ServerBase = rpc.NewServerBase(cfg, svr, &Coordinator_ServiceDesc, tokenAuthAPIs, tokenVerifier, noAuthAPIs) return svr } var _ CoordinatorServer = (*Server)(nil) + +var tokenAuthAPIs []string + +func TokenAuth(api string) bool { + tokenAuthAPIs = append(tokenAuthAPIs, api) + return true +} + +var noAuthAPIs []string + +func NoAuth(api string) bool { + noAuthAPIs = append(noAuthAPIs, api) + return true +} diff --git a/common/pkgs/rpc/coordinator/storage.go b/common/pkgs/rpc/coordinator/storage.go index d6e4c4d..19c27bd 100644 --- a/common/pkgs/rpc/coordinator/storage.go +++ b/common/pkgs/rpc/coordinator/storage.go @@ -19,6 +19,8 @@ type SelectStorageHubResp struct { Hubs []*cortypes.Hub } +var _ = TokenAuth(Coordinator_SelectStorageHub_FullMethodName) + func (c *Client) SelectStorageHub(ctx context.Context, msg *SelectStorageHub) (*SelectStorageHubResp, *rpc.CodeError) { if c.fusedErr != nil { return nil, c.fusedErr diff --git a/common/pkgs/rpc/coordinator/user.go b/common/pkgs/rpc/coordinator/user.go new file mode 100644 index 0000000..2d44a5b --- /dev/null +++ b/common/pkgs/rpc/coordinator/user.go @@ -0,0 +1,95 @@ +package corrpc + +import ( + context "context" + + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type UserService interface { + UserLogin(ctx context.Context, msg *UserLogin) (*UserLoginResp, *rpc.CodeError) + + UserLogout(ctx context.Context, msg *UserLogout) (*UserLogoutResp, *rpc.CodeError) + + UserRefreshToken(ctx context.Context, msg *UserRefreshToken) (*UserRefreshTokenResp, *rpc.CodeError) + + HubLoadAccessToken(ctx context.Context, msg *HubLoadAccessToken) (*HubLoadAccessTokenResp, *rpc.CodeError) +} + +// 客户端登录 +type UserLogin struct { + Account string + Password string +} +type UserLoginResp struct { + Token cortypes.UserAccessToken + PrivateKey string +} + +var _ = NoAuth(Coordinator_UserLogin_FullMethodName) + +func (c *Client) UserLogin(ctx context.Context, msg *UserLogin) (*UserLoginResp, *rpc.CodeError) { + if c.fusedErr != nil { + return nil, c.fusedErr + } + return rpc.UnaryClient[*UserLoginResp](c.cli.UserLogin, ctx, msg) +} +func (s *Server) UserLogin(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { + return rpc.UnaryServer(s.svrImpl.UserLogin, ctx, req) +} + +// 客户端刷新Token,原始Token会继续有效。 +type UserRefreshToken struct{} +type UserRefreshTokenResp struct { + Token cortypes.UserAccessToken + PrivateKey string +} + +var _ = TokenAuth(Coordinator_UserLogin_FullMethodName) + +func (c *Client) UserRefreshToken(ctx context.Context, msg *UserRefreshToken) (*UserRefreshTokenResp, *rpc.CodeError) { + if c.fusedErr != nil { + return nil, c.fusedErr + } + return rpc.UnaryClient[*UserRefreshTokenResp](c.cli.UserRefreshToken, ctx, msg) +} +func (s *Server) UserRefreshToken(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { + return rpc.UnaryServer(s.svrImpl.UserRefreshToken, ctx, req) +} + +// 客户端登出。会使用GRPC元数据中的TokenID和UserID来查找Token并删除。 +type UserLogout struct{} +type UserLogoutResp struct{} + +var _ = TokenAuth(Coordinator_UserLogout_FullMethodName) + +func (c *Client) UserLogout(ctx context.Context, msg *UserLogout) (*UserLogoutResp, *rpc.CodeError) { + if c.fusedErr != nil { + return nil, c.fusedErr + } + return rpc.UnaryClient[*UserLogoutResp](c.cli.UserLogout, ctx, msg) +} +func (s *Server) UserLogout(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { + return rpc.UnaryServer(s.svrImpl.UserLogout, ctx, req) +} + +// Hub服务加载AccessToken +type HubLoadAccessToken struct { + HubID cortypes.HubID + UserID cortypes.UserID + TokenID cortypes.AccessTokenID +} +type HubLoadAccessTokenResp struct { + Token cortypes.UserAccessToken +} + +func (c *Client) HubLoadAccessToken(ctx context.Context, msg *HubLoadAccessToken) (*HubLoadAccessTokenResp, *rpc.CodeError) { + if c.fusedErr != nil { + return nil, c.fusedErr + } + return rpc.UnaryClient[*HubLoadAccessTokenResp](c.cli.HubLoadAccessToken, ctx, msg) +} +func (s *Server) HubLoadAccessToken(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { + return rpc.UnaryServer(s.svrImpl.HubLoadAccessToken, ctx, req) +} diff --git a/common/pkgs/rpc/hub/client.go b/common/pkgs/rpc/hub/client.go index 441afb2..a029c41 100644 --- a/common/pkgs/rpc/hub/client.go +++ b/common/pkgs/rpc/hub/client.go @@ -6,7 +6,7 @@ import ( ) type Client struct { - addr grpcAddr + addr string con *grpc.ClientConn cli HubClient pool *Pool @@ -15,7 +15,7 @@ type Client struct { func (c *Client) Release() { if c.con != nil { - c.pool.release(c.addr) + c.pool.connPool.Release(c.addr) } } diff --git a/common/pkgs/rpc/hub/hub.pb.go b/common/pkgs/rpc/hub/hub.pb.go index 9f92568..eb7a201 100644 --- a/common/pkgs/rpc/hub/hub.pb.go +++ b/common/pkgs/rpc/hub/hub.pb.go @@ -26,7 +26,7 @@ var file_pkgs_rpc_hub_hub_proto_rawDesc = []byte{ 0x0a, 0x16, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x68, 0x75, 0x62, 0x2f, 0x68, 0x75, 0x62, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x1a, 0x12, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, 0x2f, 0x72, 0x70, 0x63, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xb8, 0x02, 0x0a, 0x03, 0x48, 0x75, 0x62, 0x12, 0x2c, 0x0a, 0x0d, + 0x72, 0x6f, 0x74, 0x6f, 0x32, 0xf5, 0x02, 0x0a, 0x03, 0x48, 0x75, 0x62, 0x12, 0x2c, 0x0a, 0x0d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x49, 0x4f, 0x50, 0x6c, 0x61, 0x6e, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x31, 0x0a, 0x0c, 0x53, 0x65, @@ -45,12 +45,16 @@ var file_pkgs_rpc_hub_hub_proto_rawDesc = []byte{ 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, - 0x40, 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, - 0x6e, 0x2f, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, - 0x70, 0x75, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, - 0x72, 0x70, 0x63, 0x2f, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x3b, 0x68, 0x75, 0x62, 0x72, 0x70, - 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x1a, 0x0d, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x3b, 0x0a, 0x1c, 0x4e, 0x6f, 0x74, 0x69, 0x66, 0x79, 0x55, 0x73, 0x65, 0x72, 0x41, 0x63, 0x63, + 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x49, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, + 0x0c, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, + 0x72, 0x70, 0x63, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, + 0x67, 0x69, 0x74, 0x6c, 0x69, 0x6e, 0x6b, 0x2e, 0x6f, 0x72, 0x67, 0x2e, 0x63, 0x6e, 0x2f, 0x63, + 0x6c, 0x6f, 0x75, 0x64, 0x72, 0x65, 0x61, 0x6d, 0x2f, 0x6a, 0x63, 0x73, 0x2d, 0x70, 0x75, 0x62, + 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x73, 0x2f, 0x72, 0x70, 0x63, + 0x2f, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x3b, 0x68, 0x75, 0x62, 0x72, 0x70, 0x63, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var file_pkgs_rpc_hub_hub_proto_goTypes = []any{ @@ -66,15 +70,17 @@ var file_pkgs_rpc_hub_hub_proto_depIdxs = []int32{ 0, // 4: hubrpc.Hub.GetIOVar:input_type -> rpc.Request 0, // 5: hubrpc.Hub.Ping:input_type -> rpc.Request 0, // 6: hubrpc.Hub.GetState:input_type -> rpc.Request - 2, // 7: hubrpc.Hub.ExecuteIOPlan:output_type -> rpc.Response - 2, // 8: hubrpc.Hub.SendIOStream:output_type -> rpc.Response - 1, // 9: hubrpc.Hub.GetIOStream:output_type -> rpc.ChunkedData - 2, // 10: hubrpc.Hub.SendIOVar:output_type -> rpc.Response - 2, // 11: hubrpc.Hub.GetIOVar:output_type -> rpc.Response - 2, // 12: hubrpc.Hub.Ping:output_type -> rpc.Response - 2, // 13: hubrpc.Hub.GetState:output_type -> rpc.Response - 7, // [7:14] is the sub-list for method output_type - 0, // [0:7] is the sub-list for method input_type + 0, // 7: hubrpc.Hub.NotifyUserAccessTokenInvalid:input_type -> rpc.Request + 2, // 8: hubrpc.Hub.ExecuteIOPlan:output_type -> rpc.Response + 2, // 9: hubrpc.Hub.SendIOStream:output_type -> rpc.Response + 1, // 10: hubrpc.Hub.GetIOStream:output_type -> rpc.ChunkedData + 2, // 11: hubrpc.Hub.SendIOVar:output_type -> rpc.Response + 2, // 12: hubrpc.Hub.GetIOVar:output_type -> rpc.Response + 2, // 13: hubrpc.Hub.Ping:output_type -> rpc.Response + 2, // 14: hubrpc.Hub.GetState:output_type -> rpc.Response + 2, // 15: hubrpc.Hub.NotifyUserAccessTokenInvalid:output_type -> rpc.Response + 8, // [8:16] is the sub-list for method output_type + 0, // [0:8] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/common/pkgs/rpc/hub/hub.proto b/common/pkgs/rpc/hub/hub.proto index 2d5f5c4..6144e49 100644 --- a/common/pkgs/rpc/hub/hub.proto +++ b/common/pkgs/rpc/hub/hub.proto @@ -16,4 +16,6 @@ service Hub { rpc Ping(rpc.Request) returns(rpc.Response); rpc GetState(rpc.Request) returns(rpc.Response); + + rpc NotifyUserAccessTokenInvalid(rpc.Request) returns(rpc.Response); } \ No newline at end of file diff --git a/common/pkgs/rpc/hub/hub_grpc.pb.go b/common/pkgs/rpc/hub/hub_grpc.pb.go index 177a722..ec65369 100644 --- a/common/pkgs/rpc/hub/hub_grpc.pb.go +++ b/common/pkgs/rpc/hub/hub_grpc.pb.go @@ -20,13 +20,14 @@ import ( const _ = grpc.SupportPackageIsVersion7 const ( - Hub_ExecuteIOPlan_FullMethodName = "/hubrpc.Hub/ExecuteIOPlan" - Hub_SendIOStream_FullMethodName = "/hubrpc.Hub/SendIOStream" - Hub_GetIOStream_FullMethodName = "/hubrpc.Hub/GetIOStream" - Hub_SendIOVar_FullMethodName = "/hubrpc.Hub/SendIOVar" - Hub_GetIOVar_FullMethodName = "/hubrpc.Hub/GetIOVar" - Hub_Ping_FullMethodName = "/hubrpc.Hub/Ping" - Hub_GetState_FullMethodName = "/hubrpc.Hub/GetState" + Hub_ExecuteIOPlan_FullMethodName = "/hubrpc.Hub/ExecuteIOPlan" + Hub_SendIOStream_FullMethodName = "/hubrpc.Hub/SendIOStream" + Hub_GetIOStream_FullMethodName = "/hubrpc.Hub/GetIOStream" + Hub_SendIOVar_FullMethodName = "/hubrpc.Hub/SendIOVar" + Hub_GetIOVar_FullMethodName = "/hubrpc.Hub/GetIOVar" + Hub_Ping_FullMethodName = "/hubrpc.Hub/Ping" + Hub_GetState_FullMethodName = "/hubrpc.Hub/GetState" + Hub_NotifyUserAccessTokenInvalid_FullMethodName = "/hubrpc.Hub/NotifyUserAccessTokenInvalid" ) // HubClient is the client API for Hub service. @@ -40,6 +41,7 @@ type HubClient interface { GetIOVar(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) Ping(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) GetState(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) + NotifyUserAccessTokenInvalid(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) } type hubClient struct { @@ -161,6 +163,15 @@ func (c *hubClient) GetState(ctx context.Context, in *rpc.Request, opts ...grpc. return out, nil } +func (c *hubClient) NotifyUserAccessTokenInvalid(ctx context.Context, in *rpc.Request, opts ...grpc.CallOption) (*rpc.Response, error) { + out := new(rpc.Response) + err := c.cc.Invoke(ctx, Hub_NotifyUserAccessTokenInvalid_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // HubServer is the server API for Hub service. // All implementations must embed UnimplementedHubServer // for forward compatibility @@ -172,6 +183,7 @@ type HubServer interface { GetIOVar(context.Context, *rpc.Request) (*rpc.Response, error) Ping(context.Context, *rpc.Request) (*rpc.Response, error) GetState(context.Context, *rpc.Request) (*rpc.Response, error) + NotifyUserAccessTokenInvalid(context.Context, *rpc.Request) (*rpc.Response, error) mustEmbedUnimplementedHubServer() } @@ -200,6 +212,9 @@ func (UnimplementedHubServer) Ping(context.Context, *rpc.Request) (*rpc.Response func (UnimplementedHubServer) GetState(context.Context, *rpc.Request) (*rpc.Response, error) { return nil, status.Errorf(codes.Unimplemented, "method GetState not implemented") } +func (UnimplementedHubServer) NotifyUserAccessTokenInvalid(context.Context, *rpc.Request) (*rpc.Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method NotifyUserAccessTokenInvalid not implemented") +} func (UnimplementedHubServer) mustEmbedUnimplementedHubServer() {} // UnsafeHubServer may be embedded to opt out of forward compatibility for this service. @@ -350,6 +365,24 @@ func _Hub_GetState_Handler(srv interface{}, ctx context.Context, dec func(interf return interceptor(ctx, in, info, handler) } +func _Hub_NotifyUserAccessTokenInvalid_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(rpc.Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HubServer).NotifyUserAccessTokenInvalid(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Hub_NotifyUserAccessTokenInvalid_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HubServer).NotifyUserAccessTokenInvalid(ctx, req.(*rpc.Request)) + } + return interceptor(ctx, in, info, handler) +} + // Hub_ServiceDesc is the grpc.ServiceDesc for Hub service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -377,6 +410,10 @@ var Hub_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetState", Handler: _Hub_GetState_Handler, }, + { + MethodName: "NotifyUserAccessTokenInvalid", + Handler: _Hub_NotifyUserAccessTokenInvalid_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/common/pkgs/rpc/hub/ioswitch.go b/common/pkgs/rpc/hub/ioswitch.go index ccd1185..a6f5bbf 100644 --- a/common/pkgs/rpc/hub/ioswitch.go +++ b/common/pkgs/rpc/hub/ioswitch.go @@ -22,6 +22,8 @@ type ExecuteIOPlan struct { } type ExecuteIOPlanResp struct{} +var _ = TokenAuth(Hub_ExecuteIOPlan_FullMethodName) + func (c *Client) ExecuteIOPlan(ctx context.Context, req *ExecuteIOPlan) (*ExecuteIOPlanResp, *rpc.CodeError) { if c.fusedErr != nil { return nil, c.fusedErr @@ -49,6 +51,8 @@ func (s *SendIOStream) SetStream(str io.Reader) { type SendIOStreamResp struct{} +var _ = TokenAuth(Hub_SendIOStream_FullMethodName) + func (c *Client) SendIOStream(ctx context.Context, req *SendIOStream) (*SendIOStreamResp, *rpc.CodeError) { if c.fusedErr != nil { return nil, c.fusedErr @@ -71,6 +75,8 @@ type GetIOStreamResp struct { Stream io.ReadCloser `json:"-"` } +var _ = TokenAuth(Hub_GetIOStream_FullMethodName) + func (r *GetIOStreamResp) GetStream() io.ReadCloser { return r.Stream } @@ -97,6 +103,8 @@ type SendIOVar struct { } type SendIOVarResp struct{} +var _ = TokenAuth(Hub_SendIOVar_FullMethodName) + func (c *Client) SendIOVar(ctx context.Context, req *SendIOVar) (*SendIOVarResp, *rpc.CodeError) { if c.fusedErr != nil { return nil, c.fusedErr @@ -119,6 +127,8 @@ type GetIOVarResp struct { Value exec.VarValue } +var _ = TokenAuth(Hub_GetIOVar_FullMethodName) + func (c *Client) GetIOVar(ctx context.Context, req *GetIOVar) (*GetIOVarResp, *rpc.CodeError) { if c.fusedErr != nil { return nil, c.fusedErr diff --git a/common/pkgs/rpc/hub/mics.go b/common/pkgs/rpc/hub/mics.go index e25ad0b..85c3fb0 100644 --- a/common/pkgs/rpc/hub/mics.go +++ b/common/pkgs/rpc/hub/mics.go @@ -15,6 +15,8 @@ type MicsSvc interface { type Ping struct{} type PingResp struct{} +var _ = TokenAuth(Hub_Ping_FullMethodName) + func (c *Client) Ping(ctx context.Context, req *Ping) (*PingResp, *rpc.CodeError) { if c.fusedErr != nil { return nil, c.fusedErr diff --git a/common/pkgs/rpc/hub/pool.go b/common/pkgs/rpc/hub/pool.go index 6710a94..515f1ff 100644 --- a/common/pkgs/rpc/hub/pool.go +++ b/common/pkgs/rpc/hub/pool.go @@ -1,114 +1,77 @@ package hubrpc import ( + "crypto/tls" + "crypto/x509" "fmt" - "sync" - "time" + "os" "gitlink.org.cn/cloudream/common/consts/errorcode" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" - grpc "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) -type PoolConfig struct{} +type PoolConfig struct { + Conn rpc.PoolConfig +} + +type PoolConfigJSON struct { + RootCA string `json:"rootCA"` + ClientCert string `json:"clientCert"` + ClientKey string `json:"clientKey"` +} + +func (c *PoolConfigJSON) Build(tokenProv rpc.AccessTokenProvider) (*PoolConfig, error) { + pc := &PoolConfig{} + pc.Conn.AccessTokenProvider = tokenProv + + rootCA, err := os.ReadFile(c.RootCA) + if err != nil { + return nil, fmt.Errorf("load root ca: %v", err) + } + pc.Conn.RootCA = x509.NewCertPool() + if !pc.Conn.RootCA.AppendCertsFromPEM(rootCA) { + return nil, fmt.Errorf("failed to parse root ca") + } + + if c.ClientCert != "" && c.ClientKey != "" { + cert, err := tls.LoadX509KeyPair(c.ClientCert, c.ClientKey) + if err != nil { + return nil, fmt.Errorf("load client cert: %v", err) + } + pc.Conn.ClientCert = &cert + } else if tokenProv == nil { + return nil, fmt.Errorf("must provide client cert or access token provider") + } + + return pc, nil +} type Pool struct { - grpcCons map[grpcAddr]*grpcCon - lock sync.Mutex -} - -type grpcAddr struct { - IP string - Port int -} - -type grpcCon struct { - grpcCon *grpc.ClientConn - refCount int - stopClosing chan any + connPool *rpc.ConnPool } func NewPool(cfg PoolConfig) *Pool { return &Pool{ - grpcCons: make(map[grpcAddr]*grpcCon), + connPool: rpc.NewConnPool(cfg.Conn), } } func (p *Pool) Get(ip string, port int) *Client { - p.lock.Lock() - defer p.lock.Unlock() - - ga := grpcAddr{IP: ip, Port: port} - con := p.grpcCons[ga] - if con == nil { - gcon, err := grpc.NewClient(fmt.Sprintf("%v:%v", ip, port), grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - return &Client{ - addr: ga, - con: nil, - pool: p, - fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), - } + addr := fmt.Sprintf("%s:%d", ip, port) + con, err := p.connPool.GetConnection(addr) + if err != nil { + return &Client{ + addr: addr, + con: nil, + pool: p, + fusedErr: rpc.Failed(errorcode.OperationFailed, err.Error()), } - - con = &grpcCon{ - grpcCon: gcon, - refCount: 0, - stopClosing: nil, - } - - p.grpcCons[ga] = con - } else if con.stopClosing != nil { - close(con.stopClosing) - con.stopClosing = nil } - con.refCount++ - return &Client{ - addr: ga, - con: con.grpcCon, - cli: NewHubClient(con.grpcCon), + addr: addr, + con: con, + cli: NewHubClient(con), pool: p, } } - -func (p *Pool) release(addr grpcAddr) { - p.lock.Lock() - defer p.lock.Unlock() - - grpcCon := p.grpcCons[addr] - if grpcCon == nil { - return - } - - grpcCon.refCount-- - grpcCon.refCount = max(grpcCon.refCount, 0) - - if grpcCon.refCount == 0 { - stopClosing := make(chan any) - grpcCon.stopClosing = stopClosing - - go func() { - select { - case <-stopClosing: - return - - case <-time.After(time.Minute): - p.lock.Lock() - defer p.lock.Unlock() - - grpcCon := p.grpcCons[addr] - if grpcCon == nil { - return - } - - if grpcCon.refCount == 0 { - grpcCon.grpcCon.Close() - delete(p.grpcCons, addr) - } - } - }() - } -} diff --git a/common/pkgs/rpc/hub/server.go b/common/pkgs/rpc/hub/server.go index a741388..89b6bf7 100644 --- a/common/pkgs/rpc/hub/server.go +++ b/common/pkgs/rpc/hub/server.go @@ -8,7 +8,7 @@ type HubAPI interface { // CacheSvc IOSwitchSvc MicsSvc - // UserSpaceSvc + UserSvc } type Server struct { @@ -17,12 +17,26 @@ type Server struct { svrImpl HubAPI } -func NewServer(cfg rpc.Config, impl HubAPI) *Server { +func NewServer(cfg rpc.Config, impl HubAPI, tokenVerifier rpc.AccessTokenVerifier) *Server { svr := &Server{ svrImpl: impl, } - svr.ServerBase = rpc.NewServerBase(cfg, svr, &Hub_ServiceDesc) + svr.ServerBase = rpc.NewServerBase(cfg, svr, &Hub_ServiceDesc, tokenAuthAPIs, tokenVerifier, noAuthAPIs) return svr } var _ HubServer = (*Server)(nil) + +var tokenAuthAPIs []string + +func TokenAuth(api string) bool { + tokenAuthAPIs = append(tokenAuthAPIs, api) + return true +} + +var noAuthAPIs []string + +func NoAuth(api string) bool { + noAuthAPIs = append(noAuthAPIs, api) + return true +} diff --git a/common/pkgs/rpc/hub/user.go b/common/pkgs/rpc/hub/user.go new file mode 100644 index 0000000..d067f6e --- /dev/null +++ b/common/pkgs/rpc/hub/user.go @@ -0,0 +1,29 @@ +package hubrpc + +import ( + context "context" + + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type UserSvc interface { + NotifyUserAccessTokenInvalid(ctx context.Context, req *NotifyUserAccessTokenInvalid) (*NotifyUserAccessTokenInvalidResp, *rpc.CodeError) +} + +// 通知用户的Token登出 +type NotifyUserAccessTokenInvalid struct { + UserID cortypes.UserID + TokenID cortypes.AccessTokenID +} +type NotifyUserAccessTokenInvalidResp struct{} + +func (c *Client) NotifyUserAccessTokenInvalid(ctx context.Context, req *NotifyUserAccessTokenInvalid) (*NotifyUserAccessTokenInvalidResp, *rpc.CodeError) { + if c.fusedErr != nil { + return nil, c.fusedErr + } + return rpc.UnaryClient[*NotifyUserAccessTokenInvalidResp](c.cli.NotifyUserAccessTokenInvalid, ctx, req) +} +func (s *Server) NotifyUserAccessTokenInvalid(ctx context.Context, req *rpc.Request) (*rpc.Response, error) { + return rpc.UnaryServer(s.svrImpl.NotifyUserAccessTokenInvalid, ctx, req) +} diff --git a/common/pkgs/rpc/pool.go b/common/pkgs/rpc/pool.go new file mode 100644 index 0000000..7f2db99 --- /dev/null +++ b/common/pkgs/rpc/pool.go @@ -0,0 +1,176 @@ +package rpc + +import ( + context "context" + "crypto/tls" + "crypto/x509" + "fmt" + "sync" + "time" + + grpc "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" +) + +type PoolConfig struct { + RootCA *x509.CertPool + // 客户端证书,与AccessTokenProvider二选一 + ClientCert *tls.Certificate + // AccessTokenProvider,与ClientCert二选一 + AccessTokenProvider AccessTokenProvider +} + +type ConnPool struct { + cfg PoolConfig + grpcCons map[string]*grpcCon + lock sync.Mutex +} + +type grpcCon struct { + grpcCon *grpc.ClientConn + refCount int + stopClosing chan any +} + +func NewConnPool(cfg PoolConfig) *ConnPool { + return &ConnPool{ + cfg: cfg, + grpcCons: make(map[string]*grpcCon), + } +} + +func (p *ConnPool) GetConnection(addr string) (*grpc.ClientConn, error) { + p.lock.Lock() + defer p.lock.Unlock() + + con := p.grpcCons[addr] + if con == nil { + gcon, err := p.connecting(addr) + if err != nil { + return nil, err + } + + con = &grpcCon{ + grpcCon: gcon, + refCount: 0, + stopClosing: nil, + } + + p.grpcCons[addr] = con + } else if con.stopClosing != nil { + close(con.stopClosing) + con.stopClosing = nil + } + + con.refCount++ + + return con.grpcCon, nil +} + +func (p *ConnPool) connecting(addr string) (*grpc.ClientConn, error) { + if p.cfg.ClientCert != nil { + gcon, err := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: p.cfg.RootCA, + Certificates: []tls.Certificate{*p.cfg.ClientCert}, + ServerName: InternalAPISNIV1, + NextProtos: []string{"h2"}, + }))) + if err != nil { + return nil, err + } + + return gcon, nil + } + + if p.cfg.AccessTokenProvider == nil { + return nil, fmt.Errorf("no client cert or access token provider") + } + + gcon, err := grpc.NewClient(addr, + grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: p.cfg.RootCA, + ServerName: ClientAPISNIV1, + NextProtos: []string{"h2"}, + })), + grpc.WithUnaryInterceptor(p.populateAccessTokenUnary), + grpc.WithStreamInterceptor(p.populateAccessTokenStream), + ) + if err != nil { + return nil, err + } + + return gcon, nil +} +func (p *ConnPool) Release(addr string) { + p.lock.Lock() + defer p.lock.Unlock() + + grpcCon := p.grpcCons[addr] + if grpcCon == nil { + return + } + + grpcCon.refCount-- + grpcCon.refCount = max(grpcCon.refCount, 0) + + if grpcCon.refCount == 0 { + stopClosing := make(chan any) + grpcCon.stopClosing = stopClosing + + go func() { + select { + case <-stopClosing: + return + + case <-time.After(time.Minute): + p.lock.Lock() + defer p.lock.Unlock() + + grpcCon := p.grpcCons[addr] + if grpcCon == nil { + return + } + + if grpcCon.refCount == 0 { + grpcCon.grpcCon.Close() + delete(p.grpcCons, addr) + } + } + }() + } +} + +func (p *ConnPool) populateAccessTokenUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + authInfo, err := p.cfg.AccessTokenProvider.GetAuthInfo() + if err != nil { + return err + } + + md := metadata.Pairs( + MetaUserID, fmt.Sprintf("%v", authInfo.UserID), + MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID), + MetaNonce, authInfo.Nonce, + MetaSignature, authInfo.Signature, + ) + + ctx = metadata.NewOutgoingContext(ctx, md) + return invoker(ctx, method, req, reply, cc, opts...) +} + +func (p *ConnPool) populateAccessTokenStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + authInfo, err := p.cfg.AccessTokenProvider.GetAuthInfo() + if err != nil { + return nil, err + } + + md := metadata.Pairs( + MetaUserID, fmt.Sprintf("%v", authInfo.UserID), + MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID), + MetaNonce, authInfo.Nonce, + MetaSignature, authInfo.Signature, + ) + + ctx = metadata.NewOutgoingContext(ctx, md) + return streamer(ctx, desc, cc, method, opts...) +} diff --git a/common/pkgs/rpc/rpc.go b/common/pkgs/rpc/rpc.go new file mode 100644 index 0000000..9ab1e3e --- /dev/null +++ b/common/pkgs/rpc/rpc.go @@ -0,0 +1 @@ +package rpc diff --git a/common/pkgs/rpc/server.go b/common/pkgs/rpc/server.go index 54746f5..e70fab9 100644 --- a/common/pkgs/rpc/server.go +++ b/common/pkgs/rpc/server.go @@ -1,11 +1,16 @@ package rpc import ( + "crypto/tls" + "crypto/x509" + "fmt" "net" + "os" "gitlink.org.cn/cloudream/common/pkgs/async" "gitlink.org.cn/cloudream/common/pkgs/logger" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) type ServerEventChan = async.UnboundChannel[RPCServerEvent] @@ -20,34 +25,84 @@ type ExitEvent struct { } type Config struct { - Listen string `json:"listen"` + Listen string `json:"listen"` + RootCA string `json:"rootCA"` + ServerCert string `json:"serverCert"` + ServerKey string `json:"serverKey"` } type ServerBase struct { - cfg Config - grpcSvr *grpc.Server - srvImpl any - svcDesc *grpc.ServiceDesc + cfg Config + grpcSvr *grpc.Server + srvImpl any + svcDesc *grpc.ServiceDesc + rootCA *x509.CertPool + serverCert tls.Certificate + accessTokenAuthAPIs map[string]bool + tokenVerifier AccessTokenVerifier + noAuthAPIs map[string]bool } -func NewServerBase(cfg Config, srvImpl any, svcDesc *grpc.ServiceDesc) *ServerBase { +func NewServerBase(cfg Config, srvImpl any, svcDesc *grpc.ServiceDesc, accessTokenAuthAPIs []string, tokenVerifier AccessTokenVerifier, noAuthAPIs []string) *ServerBase { + tokenAuthAPIs := make(map[string]bool) + for _, api := range accessTokenAuthAPIs { + tokenAuthAPIs[api] = true + } + + noAuths := make(map[string]bool) + for _, api := range noAuthAPIs { + noAuths[api] = true + } + return &ServerBase{ - cfg: cfg, - srvImpl: srvImpl, - svcDesc: svcDesc, + cfg: cfg, + srvImpl: srvImpl, + svcDesc: svcDesc, + accessTokenAuthAPIs: tokenAuthAPIs, + tokenVerifier: tokenVerifier, + noAuthAPIs: noAuths, } } func (s *ServerBase) Start() *ServerEventChan { ch := async.NewUnboundChannel[RPCServerEvent]() go func() { + svrCert, err := tls.LoadX509KeyPair(s.cfg.ServerCert, s.cfg.ServerKey) + if err != nil { + logger.Warnf("load server cert: %v", err) + ch.Send(&ExitEvent{Err: err}) + return + } + s.serverCert = svrCert + + rootCA, err := os.ReadFile(s.cfg.RootCA) + if err != nil { + logger.Warnf("load root ca: %v", err) + ch.Send(&ExitEvent{Err: err}) + return + } + + s.rootCA = x509.NewCertPool() + if !s.rootCA.AppendCertsFromPEM(rootCA) { + logger.Warnf("load root ca: failed to parse root ca") + ch.Send(&ExitEvent{Err: fmt.Errorf("failed to parse root ca")}) + return + } + logger.Infof("start serving rpc at: %v", s.cfg.Listen) lis, err := net.Listen("tcp", s.cfg.Listen) if err != nil { ch.Send(&ExitEvent{Err: err}) return } - s.grpcSvr = grpc.NewServer() + + s.grpcSvr = grpc.NewServer( + grpc.Creds(credentials.NewTLS(&tls.Config{ + GetConfigForClient: s.tlsConfigSelector, + })), + grpc.UnaryInterceptor(s.authUnary), + grpc.StreamInterceptor(s.authStream), + ) s.grpcSvr.RegisterService(s.svcDesc, s.srvImpl) err = s.grpcSvr.Serve(lis) ch.Send(&ExitEvent{Err: err}) diff --git a/common/pkgs/rpc/utils.go b/common/pkgs/rpc/utils.go index 5766701..ff8fbfb 100644 --- a/common/pkgs/rpc/utils.go +++ b/common/pkgs/rpc/utils.go @@ -38,17 +38,17 @@ func UnaryClient[Resp, Req any](apiFn func(context.Context, *Request, ...grpc.Ca func UnaryServer[Resp, Req any](apiFn func(context.Context, Req) (Resp, *CodeError), ctx context.Context, req *Request) (*Response, error) { rreq, err := serder.JSONToObjectEx[Req](req.Payload) if err != nil { - return nil, makeCodeError(errorcode.OperationFailed, err.Error()) + return nil, MakeCodeError(errorcode.OperationFailed, err.Error()) } ret, cerr := apiFn(ctx, rreq) if cerr != nil { - return nil, wrapCodeError(cerr) + return nil, WrapCodeError(cerr) } data, err := serder.ObjectToJSONEx(ret) if err != nil { - return nil, makeCodeError(errorcode.OperationFailed, err.Error()) + return nil, MakeCodeError(errorcode.OperationFailed, err.Error()) } return &Response{ @@ -120,33 +120,33 @@ func UploadStreamServer[Resp any, Req UploadStreamReq, APIRet UploadStreamAPISer cr := NewChunkedReader(req) _, data, err := cr.NextDataPart() if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } _, pr, err := cr.NextPart() if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } rreq, err := serder.JSONToObjectEx[Req](data) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } rreq.SetStream(pr) resp, cerr := apiFn(req.Context(), rreq) if cerr != nil { - return wrapCodeError(cerr) + return WrapCodeError(cerr) } respData, err := serder.ObjectToJSONEx(resp) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } err = req.SendAndClose(&Response{Payload: respData}) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } return nil @@ -211,33 +211,33 @@ func DownloadStreamClient[Resp DownloadStreamResp, Req any, APIRet DownloadStrea func DownloadStreamServer[Resp DownloadStreamResp, Req any, APIRet DownloadStreamAPIServer](apiFn func(context.Context, Req) (Resp, *CodeError), req *Request, ret APIRet) error { rreq, err := serder.JSONToObjectEx[Req](req.Payload) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } resp, cerr := apiFn(ret.Context(), rreq) if cerr != nil { - return wrapCodeError(cerr) + return WrapCodeError(cerr) } cw := NewChunkedWriter(ret) data, err := serder.ObjectToJSONEx(resp) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } err = cw.WriteDataPart("", data) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } _, err = cw.WriteStreamPart("", resp.GetStream()) if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } err = cw.Finish() if err != nil { - return makeCodeError(errorcode.OperationFailed, err.Error()) + return MakeCodeError(errorcode.OperationFailed, err.Error()) } return nil @@ -282,12 +282,12 @@ func getCodeError(err error) *CodeError { return Failed(errorcode.OperationFailed, err.Error()) } -func makeCodeError(code string, msg string) error { +func MakeCodeError(code string, msg string) error { ce, _ := status.New(codes.Unknown, "custom error").WithDetails(Failed(code, msg)) return ce.Err() } -func wrapCodeError(ce *CodeError) error { +func WrapCodeError(ce *CodeError) error { e, _ := status.New(codes.Unknown, "custom error").WithDetails(ce) return e.Err() } diff --git a/coordinator/internal/accesstoken/accesstoken.go b/coordinator/internal/accesstoken/accesstoken.go new file mode 100644 index 0000000..5c3b785 --- /dev/null +++ b/coordinator/internal/accesstoken/accesstoken.go @@ -0,0 +1,38 @@ +package accesstoken + +import ( + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" + "gorm.io/gorm" +) + +type ExitEvent = accesstoken.ExitEvent + +type CacheKey = accesstoken.CacheKey + +type Cache struct { + *accesstoken.Cache + db *db.DB +} + +func New(db *db.DB) *Cache { + c := &Cache{ + db: db, + } + c.Cache = accesstoken.New(c.load) + + return c +} + +func (c *Cache) load(key accesstoken.CacheKey) (cortypes.UserAccessToken, error) { + token, err := c.db.UserAccessToken().GetByID(c.db.DefCtx(), key.UserID, key.TokenID) + if err == gorm.ErrRecordNotFound { + return cortypes.UserAccessToken{}, accesstoken.ErrTokenNotFound + } + if err != nil { + return cortypes.UserAccessToken{}, err + } + + return token, nil +} diff --git a/coordinator/internal/cmd/cert.go b/coordinator/internal/cmd/cert.go new file mode 100644 index 0000000..24afb98 --- /dev/null +++ b/coordinator/internal/cmd/cert.go @@ -0,0 +1,212 @@ +package cmd + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" + + "github.com/spf13/cobra" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" +) + +func init() { + certCmd := cobra.Command{ + Use: "cert", + } + RootCmd.AddCommand(&certCmd) + + certRoot := cobra.Command{ + Use: "root [outputDir]", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + certRoot(args[0]) + }, + } + certCmd.AddCommand(&certRoot) + + var certFilePath string + var keyFilePath string + + certServer := cobra.Command{ + Use: "server [outputDir]", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + certServer(certFilePath, keyFilePath, args[0]) + }, + } + certServer.Flags().StringVar(&certFilePath, "cert", "", "CA certificate file path") + certServer.Flags().StringVar(&keyFilePath, "key", "", "CA key file path") + certCmd.AddCommand(&certServer) + + certClient := cobra.Command{ + Use: "client [outputDir]", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + certClient(certFilePath, keyFilePath, args[0]) + }, + } + certClient.Flags().StringVar(&certFilePath, "cert", "", "CA certificate file path") + certClient.Flags().StringVar(&keyFilePath, "key", "", "CA key file path") + certCmd.AddCommand(&certClient) +} + +func certRoot(output string) { + caPriv, _ := rsa.GenerateKey(rand.Reader, 2048) + + // 创建 CA 证书模板 + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"JCS"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), // 有效期10年 + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + } + + // 自签名 CA 证书 + caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caPriv.PublicKey, caPriv) + + // 保存 CA 证书和私钥 + writePem(filepath.Join(output, "ca_cert.pem"), "CERTIFICATE", caCertDER) + writePem(filepath.Join(output, "ca_key.pem"), "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(caPriv)) + fmt.Println("CA certificate and key saved to", output) +} + +func certServer(certFile string, keyFile string, output string) { + // 读取 CA 证书和私钥 + caCertPEM, err := os.ReadFile(certFile) + if err != nil { + fmt.Println("Failed to read CA certificate:", err) + return + } + caKeyPEM, err := os.ReadFile(keyFile) + if err != nil { + fmt.Println("Failed to read CA key:", err) + return + } + caCertPEMBlock, _ := pem.Decode(caCertPEM) + if caCertPEMBlock == nil { + fmt.Println("Failed to decode CA certificate") + return + } + caKeyPEMBlock, _ := pem.Decode(caKeyPEM) + if caKeyPEMBlock == nil { + fmt.Println("Failed to decode CA key") + return + } + + caCert, err := x509.ParseCertificate(caCertPEMBlock.Bytes) + if err != nil { + fmt.Println("Failed to parse CA certificate:", err) + return + } + + caKey, err := x509.ParsePKCS1PrivateKey(caKeyPEMBlock.Bytes) + if err != nil { + fmt.Println("Failed to parse CA key:", err) + return + } + + // 生成服务端私钥 + serverPriv, _ := rsa.GenerateKey(rand.Reader, 2048) + + // 服务端证书模板 + serverTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), // 有效期1年 + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // 添加主机名/IP 到证书 + serverTemplate.DNSNames = []string{rpc.ClientAPISNIV1, rpc.InternalAPISNIV1} + + // 用 CA 签发服务端证书 + serverCertDER, _ := x509.CreateCertificate(rand.Reader, serverTemplate, caCert, &serverPriv.PublicKey, caKey) + + // 保存服务端证书和私钥 + writePem(filepath.Join(output, "server_cert.pem"), "CERTIFICATE", serverCertDER) + writePem(filepath.Join(output, "server_key.pem"), "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(serverPriv)) + fmt.Println("Server certificate and key saved to", output) +} + +func certClient(certFile string, keyFile string, output string) { + // 读取 CA 证书和私钥 + caCertPEM, err := os.ReadFile(certFile) + if err != nil { + fmt.Println("Failed to read CA certificate:", err) + return + } + caKeyPEM, err := os.ReadFile(keyFile) + if err != nil { + fmt.Println("Failed to read CA key:", err) + return + } + caCertPEMBlock, _ := pem.Decode(caCertPEM) + if caCertPEMBlock == nil { + fmt.Println("Failed to decode CA certificate") + return + } + caKeyPEMBlock, _ := pem.Decode(caKeyPEM) + if caKeyPEMBlock == nil { + fmt.Println("Failed to decode CA key") + return + } + + caCert, err := x509.ParseCertificate(caCertPEMBlock.Bytes) + if err != nil { + fmt.Println("Failed to parse CA certificate:", err) + return + } + + caKey, err := x509.ParsePKCS1PrivateKey(caKeyPEMBlock.Bytes) + if err != nil { + fmt.Println("Failed to parse CA key:", err) + return + } + + // 生成客户端私钥 + clientPriv, _ := rsa.GenerateKey(rand.Reader, 2048) + + // 客户端证书模板 + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(3), + Subject: pkix.Name{ + CommonName: "client", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), // 有效期1年 + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + // 用 CA 签发客户端证书 + clientCertDER, _ := x509.CreateCertificate(rand.Reader, clientTemplate, caCert, &clientPriv.PublicKey, caKey) + + // 保存客户端证书和私钥 + writePem(filepath.Join(output, "client_cert.pem"), "CERTIFICATE", clientCertDER) + writePem(filepath.Join(output, "client_key.pem"), "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(clientPriv)) + fmt.Println("Client certificate and key saved to", output) +} + +func writePem(filename, pemType string, bytes []byte) { + f, _ := os.Create(filename) + pem.Encode(f, &pem.Block{Type: pemType, Bytes: bytes}) + f.Close() +} diff --git a/coordinator/internal/cmd/migrate.go b/coordinator/internal/cmd/migrate.go index 887dbc9..e86d3ac 100644 --- a/coordinator/internal/cmd/migrate.go +++ b/coordinator/internal/cmd/migrate.go @@ -42,6 +42,8 @@ func migrate(configPath string) { migrateOne(db, cortypes.Hub{}) migrateOne(db, cortypes.HubLocation{}) migrateOne(db, cortypes.User{}) + migrateOne(db, cortypes.UserAccessToken{}) + migrateOne(db, cortypes.LoadedAccessToken{}) fmt.Println("migrate success") } diff --git a/coordinator/internal/cmd/serve.go b/coordinator/internal/cmd/serve.go index 7b7d5e4..12f4c08 100644 --- a/coordinator/internal/cmd/serve.go +++ b/coordinator/internal/cmd/serve.go @@ -9,7 +9,7 @@ import ( stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" - hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/config" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/repl" @@ -44,7 +44,13 @@ func serve(configPath string) { os.Exit(1) } - stgglb.InitPools(&hubrpc.PoolConfig{}, nil) + hubRPCCfg, err := config.Cfg().HubRPC.Build(nil) + if err != nil { + logger.Errorf("build hub rpc config: %v", err) + os.Exit(1) + } + + stgglb.InitPools(hubRPCCfg, nil) db2, err := db.NewDB(&config.Cfg().DB) if err != nil { @@ -59,13 +65,18 @@ func serve(configPath string) { // } // go servePublisher(evtPub) + // 客户端访问令牌缓存 + accToken := accesstoken.New(db2) + accTokenChan := accToken.Start() + defer accToken.Stop() + // RPC服务 - rpcSvr := corrpc.NewServer(config.Cfg().RPC, myrpc.NewService(db2)) + rpcSvr := corrpc.NewServer(config.Cfg().RPC, myrpc.NewService(db2, accToken), accToken) rpcSvrChan := rpcSvr.Start() defer rpcSvr.Stop() // 定时任务 - tktk := ticktock.New(config.Cfg().TickTock, db2) + tktk := ticktock.New(config.Cfg().TickTock, db2, accToken) tktk.Start() defer tktk.Stop() @@ -74,11 +85,29 @@ func serve(configPath string) { replCh := rep.Start() /// 开始监听各个模块的事件 + accTokenEvt := accTokenChan.Receive() replEvt := replCh.Receive() rpcEvt := rpcSvrChan.Receive() loop: for { select { + case e := <-accTokenEvt.Chan(): + if e.Err != nil { + logger.Errorf("receive access token event: %v", e.Err) + break loop + } + + switch e := e.Value.(type) { + case accesstoken.ExitEvent: + if e.Err != nil { + logger.Errorf("access token cache exited with error: %v", e.Err) + } else { + logger.Info("access token cache exited") + } + break loop + } + accTokenEvt = accTokenChan.Receive() + case e := <-replEvt.Chan(): if e.Err != nil { logger.Errorf("receive repl event: %v", err) diff --git a/coordinator/internal/config/config.go b/coordinator/internal/config/config.go index bfb1643..7c77009 100644 --- a/coordinator/internal/config/config.go +++ b/coordinator/internal/config/config.go @@ -4,15 +4,17 @@ import ( log "gitlink.org.cn/cloudream/common/pkgs/logger" c "gitlink.org.cn/cloudream/common/utils/config" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/ticktock" ) type Config struct { - Logger log.Config `json:"logger"` - DB db.Config `json:"db"` - TickTock ticktock.Config `json:"tickTock"` - RPC rpc.Config `json:"rpc"` + Logger log.Config `json:"logger"` + DB db.Config `json:"db"` + TickTock ticktock.Config `json:"tickTock"` + RPC rpc.Config `json:"rpc"` + HubRPC hubrpc.PoolConfigJSON `json:"hubRPC"` } var cfg Config diff --git a/coordinator/internal/db/loaded_access_token.go b/coordinator/internal/db/loaded_access_token.go new file mode 100644 index 0000000..ab6a562 --- /dev/null +++ b/coordinator/internal/db/loaded_access_token.go @@ -0,0 +1,49 @@ +package db + +import ( + "time" + + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" + "gorm.io/gorm/clause" +) + +type LoadedAccessTokenDB struct { + *DB +} + +func (db *DB) LoadedAccessToken() *LoadedAccessTokenDB { + return &LoadedAccessTokenDB{DB: db} +} + +func (db *LoadedAccessTokenDB) GetByUserIDAndTokenID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) ([]cortypes.LoadedAccessToken, error) { + var ret []cortypes.LoadedAccessToken + err := ctx.Table("LoadedAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Find(&ret).Error + return ret, err +} + +func (*LoadedAccessTokenDB) CreateOrUpdate(ctx SQLContext, token cortypes.LoadedAccessToken) error { + return ctx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "UserID"}, {Name: "TokenID"}, {Name: "HubID"}}, + DoUpdates: clause.AssignmentColumns([]string{"LoadedAt"}), + }).Create(token).Error +} + +func (*LoadedAccessTokenDB) GetExpired(ctx SQLContext, expireAt time.Time) ([]cortypes.LoadedAccessToken, error) { + var ret []cortypes.LoadedAccessToken + err := ctx.Table("LoadedAccessToken"). + Select("LoadedAccessToken.*"). + Joins("join UserAccessToken on UserAccessToken.UserID = LoadedAccessToken.UserID and UserAccessToken.TokenID = LoadedAccessToken.TokenID"). + Where("UserAccessToken.ExpiresAt < ?", expireAt). + Find(&ret).Error + return ret, err +} + +func (*LoadedAccessTokenDB) DeleteExpired(ctx SQLContext, expireAt time.Time) error { + return ctx.Table("LoadedAccessToken"). + Where("UserID in (select UserID from UserAccessToken where ExpiresAt < ?)", expireAt). + Delete(&cortypes.LoadedAccessToken{}).Error +} + +func (db *LoadedAccessTokenDB) DeleteAllByUserIDAndTokenID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) error { + return ctx.Table("LoadedAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&cortypes.LoadedAccessToken{}).Error +} diff --git a/coordinator/internal/db/user.go b/coordinator/internal/db/user.go index 90c542c..08ceea7 100644 --- a/coordinator/internal/db/user.go +++ b/coordinator/internal/db/user.go @@ -19,14 +19,14 @@ func (db *UserDB) GetByID(ctx SQLContext, userID cortypes.UserID) (cortypes.User return ret, err } -func (db *UserDB) GetByName(ctx SQLContext, name string) (cortypes.User, error) { +func (db *UserDB) GetByAccount(ctx SQLContext, account string) (cortypes.User, error) { var ret cortypes.User - err := ctx.Table("User").Where("Name = ?", name).First(&ret).Error + err := ctx.Table("User").Where("Account = ?", account).First(&ret).Error return ret, err } -func (db *UserDB) Create(ctx SQLContext, name string) (cortypes.User, error) { - _, err := db.GetByName(ctx, name) +func (db *UserDB) Create(ctx SQLContext, account string, password string, nickName string) (cortypes.User, error) { + _, err := db.GetByAccount(ctx, account) if err == nil { return cortypes.User{}, gorm.ErrDuplicatedKey } @@ -34,7 +34,7 @@ func (db *UserDB) Create(ctx SQLContext, name string) (cortypes.User, error) { return cortypes.User{}, err } - user := cortypes.User{Name: name} + user := cortypes.User{NickName: nickName, Account: account, Password: password} err = ctx.Table("User").Create(&user).Error return user, err } diff --git a/coordinator/internal/db/user_access_token.go b/coordinator/internal/db/user_access_token.go new file mode 100644 index 0000000..a263055 --- /dev/null +++ b/coordinator/internal/db/user_access_token.go @@ -0,0 +1,33 @@ +package db + +import ( + "time" + + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type UserAccessTokenDB struct { + *DB +} + +func (db *DB) UserAccessToken() *UserAccessTokenDB { + return &UserAccessTokenDB{DB: db} +} + +func (db *UserAccessTokenDB) GetByID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) (cortypes.UserAccessToken, error) { + var ret cortypes.UserAccessToken + err := ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).First(&ret).Error + return ret, err +} + +func (*UserAccessTokenDB) Create(ctx SQLContext, token *cortypes.UserAccessToken) error { + return ctx.Table("UserAccessToken").Create(token).Error +} + +func (db *UserAccessTokenDB) DeleteByID(ctx SQLContext, userID cortypes.UserID, tokenID cortypes.AccessTokenID) error { + return ctx.Table("UserAccessToken").Where("UserID = ? AND TokenID = ?", userID, tokenID).Delete(&cortypes.UserAccessToken{}).Error +} + +func (*UserAccessTokenDB) DeleteExpired(ctx SQLContext, expireTime time.Time) error { + return ctx.Table("UserAccessToken").Where("ExpiresAt < ?", expireTime).Delete(&cortypes.UserAccessToken{}).Error +} diff --git a/coordinator/internal/repl/user.go b/coordinator/internal/repl/user.go new file mode 100644 index 0000000..268f9f5 --- /dev/null +++ b/coordinator/internal/repl/user.go @@ -0,0 +1,62 @@ +package repl + +import ( + "encoding/hex" + "fmt" + "os" + + "github.com/spf13/cobra" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +func init() { + userCmd := &cobra.Command{ + Use: "user", + Short: "user command", + } + RootCmd.AddCommand(userCmd) + + createCmd := &cobra.Command{ + Use: "create [account] [nickName]", + Short: "create a new user account", + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + userCreate(GetCmdCtx(cmd), args[0], args[1]) + }, + } + userCmd.AddCommand(createCmd) +} + +func userCreate(ctx *CommandContext, account string, nickName string) { + _, err := ctx.repl.db.User().GetByAccount(ctx.repl.db.DefCtx(), account) + if err == nil { + fmt.Printf("user %s already exists\n", account) + return + } + + fmt.Printf("input account password: ") + pass, err := term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + fmt.Println("error reading password:", err) + return + } + + passHash, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost) + if err != nil { + fmt.Println("error hashing password:", err) + return + } + + user, err := db.DoTx02(ctx.repl.db, func(tx db.SQLContext) (cortypes.User, error) { + return ctx.repl.db.User().Create(tx, account, hex.EncodeToString(passHash), nickName) + }) + if err != nil { + fmt.Println("error creating user:", err) + return + } + + fmt.Printf("user %s created\n", user.Account) +} diff --git a/coordinator/internal/rpc/service.go b/coordinator/internal/rpc/service.go index 03ecdee..c0bc216 100644 --- a/coordinator/internal/rpc/service.go +++ b/coordinator/internal/rpc/service.go @@ -1,15 +1,18 @@ package rpc import ( + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" ) type Service struct { - db *db.DB + db *db.DB + accessToken *accesstoken.Cache } -func NewService(db *db.DB) *Service { +func NewService(db *db.DB, accessToken *accesstoken.Cache) *Service { return &Service{ - db: db, + db: db, + accessToken: accessToken, } } diff --git a/coordinator/internal/rpc/user.go b/coordinator/internal/rpc/user.go new file mode 100644 index 0000000..a64a519 --- /dev/null +++ b/coordinator/internal/rpc/user.go @@ -0,0 +1,227 @@ +package rpc + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "fmt" + "time" + + "github.com/google/uuid" + "gitlink.org.cn/cloudream/common/consts/errorcode" + "gitlink.org.cn/cloudream/common/pkgs/logger" + stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +func (svc *Service) UserLogin(ctx context.Context, msg *corrpc.UserLogin) (*corrpc.UserLoginResp, *rpc.CodeError) { + log := logger.WithField("Account", msg.Account) + + user, err := svc.db.User().GetByAccount(svc.db.DefCtx(), msg.Account) + if err != nil { + if err == gorm.ErrRecordNotFound { + log.Warnf("account not found") + return nil, rpc.Failed(errorcode.DataNotFound, "account not found") + } + + log.Warnf("getting account: %v", err) + return nil, rpc.Failed(errorcode.OperationFailed, "getting account: %v", err) + } + + dbPass, err := hex.DecodeString(user.Password) + if err != nil { + log.Warnf("decoding password: %v", err) + return nil, rpc.Failed(errorcode.OperationFailed, "decoding password: %v", err) + } + + if bcrypt.CompareHashAndPassword(dbPass, []byte(msg.Password)) != nil { + log.Warnf("password not match") + return nil, rpc.Failed(errorcode.Unauthorized, "password not match") + } + + pubKey, priKey, err := ed25519.GenerateKey(nil) + if err != nil { + log.Warnf("generating key: %v", err) + return nil, rpc.Failed(errorcode.OperationFailed, "generating key: %v", err) + } + + pubKeyStr := hex.EncodeToString(pubKey) + nowTime := time.Now() + token := cortypes.UserAccessToken{ + UserID: user.UserID, + TokenID: cortypes.AccessTokenID(uuid.NewString()), + PublicKey: pubKeyStr, + ExpiresAt: nowTime.Add(time.Hour), + CreatedAt: nowTime, + } + + err = svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token) + if err != nil { + log.Warnf("creating token: %v", err) + return nil, rpc.Failed(errorcode.OperationFailed, "creating token: %v", err) + } + + log.Infof("login success, token expires at %v", token.ExpiresAt) + + return &corrpc.UserLoginResp{ + Token: token, + PrivateKey: hex.EncodeToString(priKey), + }, nil +} + +func (svc *Service) UserRefreshToken(ctx context.Context, msg *corrpc.UserRefreshToken) (*corrpc.UserRefreshTokenResp, *rpc.CodeError) { + authInfo, ok := rpc.GetAuthInfo(ctx) + if !ok { + return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized") + } + + log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) + + pubKey, priKey, err := ed25519.GenerateKey(nil) + if err != nil { + log.Warnf("generating key: %v", err) + return nil, rpc.Failed(errorcode.OperationFailed, "generating key: %v", err) + } + + pubKeyStr := hex.EncodeToString(pubKey) + nowTime := time.Now() + token := cortypes.UserAccessToken{ + UserID: authInfo.UserID, + TokenID: cortypes.AccessTokenID(uuid.NewString()), + PublicKey: pubKeyStr, + ExpiresAt: nowTime.Add(time.Hour), + CreatedAt: nowTime, + } + + err = svc.db.UserAccessToken().Create(svc.db.DefCtx(), &token) + if err != nil { + log.Warnf("creating token: %v", err) + return nil, rpc.Failed(errorcode.OperationFailed, "creating token: %v", err) + } + + log.Infof("refresh token success, new token expires at %v", token.ExpiresAt) + + return &corrpc.UserRefreshTokenResp{ + Token: token, + PrivateKey: hex.EncodeToString(priKey), + }, nil +} + +func (svc *Service) UserLogout(ctx context.Context, msg *corrpc.UserLogout) (*corrpc.UserLogoutResp, *rpc.CodeError) { + authInfo, ok := rpc.GetAuthInfo(ctx) + if !ok { + return nil, rpc.Failed(errorcode.Unauthorized, "unauthorized") + } + + log := logger.WithField("UserID", authInfo.UserID).WithField("TokenID", authInfo.AccessTokenID) + + loaded, err := db.DoTx02(svc.db, func(tx db.SQLContext) ([]cortypes.LoadedAccessToken, error) { + token, err := svc.db.UserAccessToken().GetByID(tx, authInfo.UserID, authInfo.AccessTokenID) + if err != nil { + return nil, err + } + + err = svc.db.UserAccessToken().DeleteByID(tx, token.UserID, token.TokenID) + if err != nil { + return nil, err + } + + loaded, err := svc.db.LoadedAccessToken().GetByUserIDAndTokenID(tx, token.UserID, token.TokenID) + if err != nil { + return nil, err + } + + err = svc.db.LoadedAccessToken().DeleteAllByUserIDAndTokenID(tx, token.UserID, token.TokenID) + if err != nil { + return nil, err + } + + return loaded, nil + }) + if err != nil { + log.Warnf("delete access token: %v", err) + if err == gorm.ErrRecordNotFound { + return nil, rpc.Failed(errorcode.DataNotFound, "token not found") + } + + return nil, rpc.Failed(errorcode.OperationFailed, "delete access token: %v", err) + } + + svc.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ + UserID: authInfo.UserID, + TokenID: authInfo.AccessTokenID, + }) + + var loadedHubIDs []cortypes.HubID + for _, l := range loaded { + loadedHubIDs = append(loadedHubIDs, l.HubID) + } + + svc.notifyLoadedHubs(authInfo.UserID, authInfo.AccessTokenID, loadedHubIDs) + + return &corrpc.UserLogoutResp{}, nil +} + +func (svc *Service) notifyLoadedHubs(userID cortypes.UserID, tokenID cortypes.AccessTokenID, loadedHubIDs []cortypes.HubID) { + log := logger.WithField("UserID", userID).WithField("TokenID", tokenID) + + loadedHubs, err := svc.db.Hub().BatchGetByID(svc.db.DefCtx(), loadedHubIDs) + if err != nil { + log.Warnf("getting hubs: %v", err) + return + } + + for _, l := range loadedHubs { + addr, ok := l.Address.(*cortypes.GRPCAddressInfo) + if !ok { + continue + } + + cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort) + // 不关心返回值 + cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{ + UserID: userID, + TokenID: tokenID, + }) + cli.Release() + } +} + +func (svc *Service) HubLoadAccessToken(ctx context.Context, msg *corrpc.HubLoadAccessToken) (*corrpc.HubLoadAccessTokenResp, *rpc.CodeError) { + token, err := db.DoTx02(svc.db, func(tx db.SQLContext) (cortypes.UserAccessToken, error) { + token, err := svc.db.UserAccessToken().GetByID(tx, msg.UserID, msg.TokenID) + if err != nil { + return cortypes.UserAccessToken{}, err + } + + err = svc.db.LoadedAccessToken().CreateOrUpdate(tx, cortypes.LoadedAccessToken{ + UserID: msg.UserID, + TokenID: msg.TokenID, + HubID: msg.HubID, + LoadedAt: time.Now(), + }) + if err != nil { + return cortypes.UserAccessToken{}, fmt.Errorf("creating access token loaded record: %v", err) + } + + return token, nil + }) + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, rpc.Failed(errorcode.DataNotFound, "token not found") + } + + return nil, rpc.Failed(errorcode.OperationFailed, "loading access token: %v", err) + } + + return &corrpc.HubLoadAccessTokenResp{ + Token: token, + }, nil +} diff --git a/coordinator/internal/ticktock/clear_expired_access_token.go b/coordinator/internal/ticktock/clear_expired_access_token.go new file mode 100644 index 0000000..0aaa802 --- /dev/null +++ b/coordinator/internal/ticktock/clear_expired_access_token.go @@ -0,0 +1,112 @@ +package ticktock + +import ( + "context" + "fmt" + "time" + + "gitlink.org.cn/cloudream/common/pkgs/logger" + "gitlink.org.cn/cloudream/common/utils/reflect2" + stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" + + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type ClearExpiredAccessToken struct { +} + +func (j *ClearExpiredAccessToken) Name() string { + return reflect2.TypeNameOf[ClearExpiredAccessToken]() +} + +func (j *ClearExpiredAccessToken) Execute(t *TickTock) { + log := logger.WithType[ClearExpiredAccessToken]("TickTock") + log.Infof("job start") + startTime := time.Now() + defer func() { + log.Infof("job end, time: %v", time.Since(startTime)) + }() + + expired, err := db.DoTx02(t.db, func(tx db.SQLContext) ([]cortypes.LoadedAccessToken, error) { + nowTime := time.Now() + expired, err := t.db.LoadedAccessToken().GetExpired(tx, nowTime) + if err != nil { + return nil, fmt.Errorf("get expired access token load record: %w", err) + } + + err = t.db.LoadedAccessToken().DeleteExpired(tx, nowTime) + if err != nil { + return nil, fmt.Errorf("delete expired access token load record: %w", err) + } + + err = t.db.UserAccessToken().DeleteExpired(tx, nowTime) + if err != nil { + return nil, fmt.Errorf("delete expired user access token: %w", err) + } + + return expired, nil + }) + if err != nil { + log.Warn(err.Error()) + return + } + + uniToken := make(map[accesstoken.CacheKey]bool) + for _, t := range expired { + uniToken[accesstoken.CacheKey{ + UserID: t.UserID, + TokenID: t.TokenID, + }] = true + } + + log.Infof("%v expired access token cleared", len(uniToken)) + + // 通知本服务的AccessToken缓存失效 + for k := range uniToken { + t.accessToken.NotifyTokenInvalid(k) + } + + // 通知所有加载了失效Token的Hub + + var loadedHubIDs []cortypes.HubID + for _, e := range expired { + loadedHubIDs = append(loadedHubIDs, e.HubID) + } + + loadedHubs, err := t.db.Hub().BatchGetByID(t.db.DefCtx(), loadedHubIDs) + if err != nil { + log.Warnf("getting hubs: %v", err) + return + } + + hubMap := make(map[cortypes.HubID]cortypes.Hub) + for _, h := range loadedHubs { + hubMap[h.HubID] = h + } + + for _, e := range expired { + h, ok := hubMap[e.HubID] + if !ok { + continue + } + addr, ok := h.Address.(*cortypes.GRPCAddressInfo) + if !ok { + continue + } + + cli := stgglb.HubRPCPool.Get(addr.ExternalIP, addr.ExternalGRPCPort) + // 不关心返回值 + _, err := cli.NotifyUserAccessTokenInvalid(context.Background(), &hubrpc.NotifyUserAccessTokenInvalid{ + UserID: e.UserID, + TokenID: e.TokenID, + }) + if err != nil { + log.Warnf("notify hub %v: %v", h.HubID, err) + } + cli.Release() + } + +} diff --git a/coordinator/internal/ticktock/ticktock.go b/coordinator/internal/ticktock/ticktock.go index f63fbf0..c4e21e4 100644 --- a/coordinator/internal/ticktock/ticktock.go +++ b/coordinator/internal/ticktock/ticktock.go @@ -6,6 +6,7 @@ import ( "github.com/go-co-op/gocron/v2" "gitlink.org.cn/cloudream/common/pkgs/logger" + "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/coordinator/internal/db" ) @@ -20,19 +21,21 @@ type cronJob struct { } type TickTock struct { - cfg Config - sch gocron.Scheduler - jobs map[string]cronJob - db *db.DB + cfg Config + sch gocron.Scheduler + jobs map[string]cronJob + db *db.DB + accessToken *accesstoken.Cache } -func New(cfg Config, db *db.DB) *TickTock { +func New(cfg Config, db *db.DB, accessToken *accesstoken.Cache) *TickTock { sch, _ := gocron.NewScheduler() t := &TickTock{ - cfg: cfg, - sch: sch, - jobs: map[string]cronJob{}, - db: db, + cfg: cfg, + sch: sch, + jobs: map[string]cronJob{}, + db: db, + accessToken: accessToken, } t.initJobs() return t @@ -70,4 +73,6 @@ func (t *TickTock) addJob(job Job, duration gocron.JobDefinition) { func (t *TickTock) initJobs() { t.addJob(&CheckHubState{}, gocron.DurationJob(time.Minute*5)) + + t.addJob(&ClearExpiredAccessToken{}, gocron.DurationJob(time.Minute*5)) } diff --git a/coordinator/types/types.go b/coordinator/types/types.go index a83b2cb..a1117e3 100644 --- a/coordinator/types/types.go +++ b/coordinator/types/types.go @@ -68,15 +68,6 @@ func (HubConnectivity) TableName() string { return "HubConnectivity" } -type User struct { - UserID UserID `gorm:"column:UserID; primaryKey; type:bigint; autoIncrement" json:"userID"` - Name string `gorm:"column:Name; type:varchar(255); not null" json:"name"` -} - -func (User) TableName() string { - return "User" -} - type HubLocation struct { HubID HubID `gorm:"column:HubID; type:bigint" json:"hubID"` StorageName string `gorm:"column:StorageName; type:varchar(255); not null" json:"storageName"` @@ -86,3 +77,40 @@ type HubLocation struct { func (HubLocation) TableName() string { return "HubLocation" } + +type User struct { + UserID UserID `gorm:"column:UserID; primaryKey; type:bigint; autoIncrement" json:"userID"` + NickName string `gorm:"column:NickName; type:varchar(255); not null" json:"nickName"` + Account string `gorm:"column:Account; type:varchar(255); not null" json:"account"` + // bcrypt哈希过的密码,带有盐值 + Password string `gorm:"column:Password; type:varchar(255); not null" json:"password"` +} + +func (User) TableName() string { + return "User" +} + +type AccessTokenID string + +type UserAccessToken struct { + UserID UserID `gorm:"column:UserID; primaryKey; type:bigint" json:"userID"` + TokenID AccessTokenID `gorm:"column:TokenID; primaryKey; type:char(36); not null" json:"tokenID"` + PublicKey string `gorm:"column:PublicKey; type:char(64); not null" json:"publicKey"` + ExpiresAt time.Time `gorm:"column:ExpiresAt; type:datetime" json:"expiresAt"` + CreatedAt time.Time `gorm:"column:CreatedAt; type:datetime" json:"createdAt"` +} + +func (UserAccessToken) TableName() string { + return "UserAccessToken" +} + +type LoadedAccessToken struct { + UserID UserID `gorm:"column:UserID; primaryKey; type:bigint" json:"userID"` + TokenID AccessTokenID `gorm:"column:TokenID; primaryKey; type:char(36); not null" json:"tokenID"` + HubID HubID `gorm:"column:HubID; primaryKey; type:bigint" json:"hubID"` + LoadedAt time.Time `gorm:"column:LoadedAt; type:datetime" json:"loadedAt"` +} + +func (LoadedAccessToken) TableName() string { + return "LoadedAccessToken" +} diff --git a/go.mod b/go.mod index e78cd4d..3b29155 100644 --- a/go.mod +++ b/go.mod @@ -28,10 +28,11 @@ require ( github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.10.0 gitlink.org.cn/cloudream/common v0.0.0 + golang.org/x/crypto v0.38.0 golang.org/x/net v0.35.0 - golang.org/x/sync v0.13.0 - golang.org/x/sys v0.32.0 - golang.org/x/term v0.31.0 + golang.org/x/sync v0.14.0 + golang.org/x/sys v0.33.0 + golang.org/x/term v0.32.0 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.36.6 gorm.io/gorm v1.25.12 @@ -70,9 +71,8 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect go.mongodb.org/mongo-driver v1.12.0 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.37.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/text v0.25.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 47a0842..62a1875 100644 --- a/go.sum +++ b/go.sum @@ -243,8 +243,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= @@ -275,8 +275,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -300,16 +300,16 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= -golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -319,8 +319,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/hub/internal/accesstoken/accesstoken.go b/hub/internal/accesstoken/accesstoken.go new file mode 100644 index 0000000..ff82c7e --- /dev/null +++ b/hub/internal/accesstoken/accesstoken.go @@ -0,0 +1,49 @@ +package accesstoken + +import ( + "context" + + "gitlink.org.cn/cloudream/common/consts/errorcode" + stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" + corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" + cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" +) + +type ExitEvent = accesstoken.ExitEvent + +type CacheKey = accesstoken.CacheKey + +type Cache struct { + localHubID cortypes.HubID + *accesstoken.Cache +} + +func New(localHubID cortypes.HubID) *Cache { + c := &Cache{ + localHubID: localHubID, + } + c.Cache = accesstoken.New(c.load) + + return c +} + +func (c *Cache) load(key accesstoken.CacheKey) (cortypes.UserAccessToken, error) { + corCli := stgglb.CoordinatorRPCPool.Get() + defer corCli.Release() + + tokenResp, cerr := corCli.HubLoadAccessToken(context.Background(), &corrpc.HubLoadAccessToken{ + UserID: key.UserID, + TokenID: key.TokenID, + HubID: c.localHubID, + }) + if cerr != nil { + if cerr.Code == errorcode.DataNotFound { + return cortypes.UserAccessToken{}, accesstoken.ErrTokenNotFound + } + + return cortypes.UserAccessToken{}, cerr.ToError() + } + + return tokenResp.Token, nil +} diff --git a/hub/internal/cmd/serve.go b/hub/internal/cmd/serve.go index 4b7bc4e..a46b62f 100644 --- a/hub/internal/cmd/serve.go +++ b/hub/internal/cmd/serve.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/storage/pool" + "gitlink.org.cn/cloudream/jcs-pub/hub/internal/accesstoken" "gitlink.org.cn/cloudream/jcs-pub/hub/internal/http" myrpc "gitlink.org.cn/cloudream/jcs-pub/hub/internal/rpc" @@ -22,7 +23,7 @@ import ( "gitlink.org.cn/cloudream/jcs-pub/hub/internal/config" "gitlink.org.cn/cloudream/jcs-pub/hub/internal/ticktock" - coormq "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" + corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" ) func init() { @@ -47,22 +48,36 @@ type serveOptions struct { } func serve(configPath string, opts serveOptions) { + // 加载服务配置 err := config.Init(configPath) if err != nil { fmt.Printf("init config failed, err: %s", err.Error()) os.Exit(1) } + // 初始化日志 err = logger.Init(&config.Cfg().Logger) if err != nil { fmt.Printf("init logger failed, err: %s", err.Error()) os.Exit(1) } + // 初始化全局变量 stgglb.InitLocal(config.Cfg().Local) - stgglb.InitPools(&hubrpc.PoolConfig{}, &config.Cfg().CoordinatorRPC) - // stgglb.Stats.SetupHubStorageTransfer(*config.Cfg().Local.HubID) - // stgglb.Stats.SetupHubTransfer(*config.Cfg().Local.HubID) + + // 初始化各服务客户端的连接池 + corRPCCfg, err := config.Cfg().CoordinatorRPC.Build(nil) + if err != nil { + logger.Errorf("building coordinator rpc config: %v", err) + os.Exit(1) + } + hubRPCCfg, err := config.Cfg().HubRPC.Build(nil) + if err != nil { + logger.Errorf("building hub rpc config: %v", err) + os.Exit(1) + } + stgglb.InitPools(hubRPCCfg, corRPCCfg) + // 获取Hub配置 hubCfg := downloadHubConfig() @@ -109,13 +124,19 @@ func serve(configPath string, opts serveOptions) { tktk.Start() defer tktk.Stop() + // 客户端访问令牌管理器 + accToken := accesstoken.New(config.Cfg().ID) + accTokenChan := accToken.Start() + defer accToken.Stop() + // RPC服务 - rpcSvr := hubrpc.NewServer(config.Cfg().RPC, myrpc.NewService(&worker, stgPool)) + rpcSvr := hubrpc.NewServer(config.Cfg().RPC, myrpc.NewService(&worker, stgPool, accToken), accToken) rpcSvrChan := rpcSvr.Start() defer rpcSvr.Stop() /// 开始监听各个模块的事件 evtPubEvt := evtPubChan.Receive() + accTokenEvt := accTokenChan.Receive() rpcEvt := rpcSvrChan.Receive() httpEvt := httpChan.Receive() @@ -145,6 +166,23 @@ loop: } evtPubEvt = evtPubChan.Receive() + case e := <-accTokenEvt.Chan(): + if e.Err != nil { + logger.Errorf("receive access token event: %v", err) + break loop + } + + switch e := e.Value.(type) { + case accesstoken.ExitEvent: + if e.Err != nil { + logger.Errorf("access token manager exited with error: %v", e.Err) + } else { + logger.Info("access token manager exited") + } + break loop + } + accTokenEvt = accTokenChan.Receive() + case e := <-rpcEvt.Chan(): if e.Err != nil { logger.Errorf("receive rpc event: %v", e.Err) @@ -179,14 +217,14 @@ loop: } -func downloadHubConfig() coormq.GetHubConfigResp { +func downloadHubConfig() corrpc.GetHubConfigResp { coorCli := stgglb.CoordinatorRPCPool.Get() defer coorCli.Release() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - cfgResp, cerr := coorCli.GetHubConfig(ctx, coormq.ReqGetHubConfig(cortypes.HubID(config.Cfg().ID))) + cfgResp, cerr := coorCli.GetHubConfig(ctx, corrpc.ReqGetHubConfig(cortypes.HubID(config.Cfg().ID))) if cerr != nil { logger.Errorf("getting hub config: %v", cerr) os.Exit(1) diff --git a/hub/internal/config/config.go b/hub/internal/config/config.go index bb8121d..6de9797 100644 --- a/hub/internal/config/config.go +++ b/hub/internal/config/config.go @@ -6,6 +6,7 @@ import ( stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/sysevent" cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types" "gitlink.org.cn/cloudream/jcs-pub/hub/internal/http" @@ -17,7 +18,8 @@ type Config struct { Local stgglb.LocalMachineInfo `json:"local"` RPC rpc.Config `json:"rpc"` HTTP *http.Config `json:"http"` - CoordinatorRPC corrpc.PoolConfig `json:"coordinatorRPC"` + CoordinatorRPC corrpc.PoolConfigJSON `json:"coordinatorRPC"` + HubRPC hubrpc.PoolConfigJSON `json:"hubRPC"` Logger log.Config `json:"logger"` SysEvent sysevent.Config `json:"sysEvent"` TickTock ticktock.Config `json:"tickTock"` diff --git a/hub/internal/rpc/rpc.go b/hub/internal/rpc/rpc.go index 72ede2d..26ab330 100644 --- a/hub/internal/rpc/rpc.go +++ b/hub/internal/rpc/rpc.go @@ -4,17 +4,20 @@ import ( "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/storage/pool" + "gitlink.org.cn/cloudream/jcs-pub/hub/internal/accesstoken" ) type Service struct { - swWorker *exec.Worker - stgPool *pool.Pool + swWorker *exec.Worker + stgPool *pool.Pool + accessToken *accesstoken.Cache } -func NewService(swWorker *exec.Worker, stgPool *pool.Pool) *Service { +func NewService(swWorker *exec.Worker, stgPool *pool.Pool, accessToken *accesstoken.Cache) *Service { return &Service{ - swWorker: swWorker, - stgPool: stgPool, + swWorker: swWorker, + stgPool: stgPool, + accessToken: accessToken, } } diff --git a/hub/internal/rpc/user.go b/hub/internal/rpc/user.go new file mode 100644 index 0000000..4078a1b --- /dev/null +++ b/hub/internal/rpc/user.go @@ -0,0 +1,19 @@ +package rpc + +import ( + "context" + + "gitlink.org.cn/cloudream/common/pkgs/logger" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" +) + +func (s *Service) NotifyUserAccessTokenInvalid(ctx context.Context, msg *hubrpc.NotifyUserAccessTokenInvalid) (*hubrpc.NotifyUserAccessTokenInvalidResp, *rpc.CodeError) { + s.accessToken.NotifyTokenInvalid(accesstoken.CacheKey{ + UserID: msg.UserID, + TokenID: msg.TokenID, + }) + logger.WithField("UserID", msg.UserID).WithField("TokenID", msg.TokenID).Infof("user access token invalid") + return &hubrpc.NotifyUserAccessTokenInvalidResp{}, nil +}