176 lines
3.7 KiB
Go
176 lines
3.7 KiB
Go
package accesstoken
|
||
|
||
import (
|
||
"context"
|
||
"crypto/ed25519"
|
||
"crypto/rand"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"sync"
|
||
"time"
|
||
|
||
"gitlink.org.cn/cloudream/common/pkgs/async"
|
||
"gitlink.org.cn/cloudream/common/pkgs/logger"
|
||
stgglb "gitlink.org.cn/cloudream/jcs-pub/common/globals"
|
||
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/accesstoken"
|
||
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc"
|
||
corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator"
|
||
cortypes "gitlink.org.cn/cloudream/jcs-pub/coordinator/types"
|
||
)
|
||
|
||
// Event types reported by a Keeper on the channel returned from Start.
type (
	// KeeperEvent is the marker interface shared by every event a Keeper emits.
	KeeperEvent interface {
		IsAccessTokenKeeper() bool
	}

	// ExitEvent is sent after the Keeper's background goroutine leaves its
	// refresh loop. The embedded KeeperEvent only makes ExitEvent satisfy the
	// interface. Start sends it as ExitEvent{}, so the embedded value (and Err)
	// are nil, and calling IsAccessTokenKeeper on it would panic.
	ExitEvent struct {
		KeeperEvent
		Err error
	}
)
|
||
|
||
type Keeper struct {
|
||
cfg Config
|
||
enabled bool
|
||
token cortypes.UserAccessToken
|
||
priKey ed25519.PrivateKey
|
||
lock sync.RWMutex
|
||
done chan any
|
||
}
|
||
|
||
func New(cfg Config, tempCli *corrpc.TempClient) (*Keeper, error) {
|
||
loginResp, cerr := tempCli.UserLogin(context.Background(), &corrpc.UserLogin{
|
||
Account: cfg.Account,
|
||
Password: cfg.Password,
|
||
})
|
||
if cerr != nil {
|
||
return nil, fmt.Errorf("login: %w", cerr.ToError())
|
||
}
|
||
|
||
priKey, err := hex.DecodeString(loginResp.PrivateKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("decode private key: %w", err)
|
||
}
|
||
|
||
return &Keeper{
|
||
cfg: cfg,
|
||
enabled: true,
|
||
token: loginResp.Token,
|
||
priKey: priKey,
|
||
done: make(chan any, 1),
|
||
}, nil
|
||
}
|
||
|
||
func NewDisabled() *Keeper {
|
||
return &Keeper{
|
||
done: make(chan any, 1),
|
||
enabled: false,
|
||
}
|
||
}
|
||
|
||
func (k *Keeper) Start() *async.UnboundChannel[KeeperEvent] {
|
||
log := logger.WithField("Mod", "Keeper")
|
||
|
||
ch := async.NewUnboundChannel[KeeperEvent]()
|
||
|
||
go func() {
|
||
if !k.enabled {
|
||
return
|
||
}
|
||
|
||
k.lock.RLock()
|
||
log.Infof("login success, token expires at %v", k.token.ExpiresAt)
|
||
k.lock.RUnlock()
|
||
|
||
ticker := time.NewTicker(time.Minute)
|
||
defer ticker.Stop()
|
||
|
||
loop:
|
||
for {
|
||
select {
|
||
case <-k.done:
|
||
break loop
|
||
|
||
case <-ticker.C:
|
||
k.lock.RLock()
|
||
token := k.token
|
||
k.lock.RUnlock()
|
||
|
||
// 当前Token已经过期,说明之前的刷新都失败了,打个日志
|
||
if time.Now().After(token.ExpiresAt) {
|
||
log.Warnf("token expired at %v !", token.ExpiresAt)
|
||
}
|
||
|
||
// 在Token到期前5分钟时就要开始刷新Token
|
||
|
||
tokenDeadline := token.ExpiresAt.Add(-time.Minute * 5)
|
||
if time.Now().Before(tokenDeadline) {
|
||
continue
|
||
}
|
||
|
||
corCli := stgglb.CoordinatorRPCPool.Get()
|
||
refResp, cerr := corCli.UserRefreshToken(context.Background(), &corrpc.UserRefreshToken{})
|
||
if cerr != nil {
|
||
log.Warnf("refresh token: %v", cerr)
|
||
corCli.Release()
|
||
continue
|
||
}
|
||
|
||
priKey, err := hex.DecodeString(refResp.PrivateKey)
|
||
if err != nil {
|
||
log.Warnf("decode private key: %v", err)
|
||
corCli.Release()
|
||
continue
|
||
}
|
||
|
||
log.Infof("refresh token success, new token expires at %v", refResp.Token.ExpiresAt)
|
||
|
||
k.lock.Lock()
|
||
k.token = refResp.Token
|
||
k.priKey = priKey
|
||
k.lock.Unlock()
|
||
|
||
corCli.Release()
|
||
}
|
||
}
|
||
ch.Send(ExitEvent{})
|
||
}()
|
||
|
||
return ch
|
||
}
|
||
|
||
func (k *Keeper) Stop() {
|
||
select {
|
||
case k.done <- true:
|
||
default:
|
||
}
|
||
}
|
||
|
||
func (k *Keeper) GetAuthInfo() (rpc.AccessTokenAuthInfo, error) {
|
||
if !k.enabled {
|
||
return rpc.AccessTokenAuthInfo{}, fmt.Errorf("function disabled")
|
||
}
|
||
|
||
k.lock.RLock()
|
||
token := k.token
|
||
k.lock.RUnlock()
|
||
|
||
bytes := make([]byte, 8)
|
||
|
||
_, err := rand.Read(bytes)
|
||
if err != nil {
|
||
return rpc.AccessTokenAuthInfo{}, fmt.Errorf("generate nonce: %w", err)
|
||
}
|
||
|
||
nonce := hex.EncodeToString(bytes)
|
||
stringToSign := accesstoken.MakeStringToSign(token.UserID, token.TokenID, nonce)
|
||
|
||
signBytes := ed25519.Sign(k.priKey, []byte(stringToSign))
|
||
signature := hex.EncodeToString(signBytes)
|
||
|
||
return rpc.AccessTokenAuthInfo{
|
||
UserID: token.UserID,
|
||
AccessTokenID: token.TokenID,
|
||
Nonce: nonce,
|
||
Signature: signature,
|
||
}, nil
|
||
}
|