JCS-pub/common/pkgs/distlock/service.go

195 lines
4.0 KiB
Go

package distlock
import (
"context"
"fmt"
"sync"
"time"
"gitlink.org.cn/cloudream/common/pkgs/future"
"gitlink.org.cn/cloudream/common/pkgs/trie"
"gitlink.org.cn/cloudream/common/utils/lo2"
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/lockprovider"
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/types"
)
// AcquireOption holds the tunable settings of Service.Acquire.
type AcquireOption struct {
	// Timeout bounds how long Acquire waits for the locks. Acquire uses 10s
	// unless overridden; zero means no deadline is applied.
	Timeout time.Duration
}

// AcquireOptionFn mutates an AcquireOption; any number may be passed to Acquire
// and they are applied in order.
type AcquireOptionFn func(opt *AcquireOption)

// WithTimeout returns an option that replaces Acquire's timeout with timeout.
// Passing zero disables the deadline.
func WithTimeout(timeout time.Duration) AcquireOptionFn {
	setTimeout := func(o *AcquireOption) {
		o.Timeout = timeout
	}
	return setTimeout
}
// Service grants lock requests by dispatching each lock to the LockProvider
// registered for its path. Requests that cannot be granted immediately are
// queued and retried whenever locks are released.
type Service struct {
	// lock serializes access to the fields below in this file's methods.
	lock *sync.Mutex
	// provdersTrie maps lock path patterns to the provider responsible for them.
	provdersTrie *trie.Trie[types.LockProvider]
	// acquirings is the queue of requests waiting for their locks; it is
	// retried in order on every release.
	acquirings []*acquireInfo
	// nextReqID is the counter used to generate RequestIDs.
	nextReqID int64
}
// NewService builds a Service with an empty wait queue and registers the
// built-in lock providers. Currently only the shard store provider is
// registered, under the pattern [ShardStoreLockPathPrefix, trie.WORD_ANY]
// (WORD_ANY presumably matches any single path segment — confirm in trie).
func NewService() *Service {
	providers := trie.NewTrie[types.LockProvider]()

	shardNode := providers.Create([]any{lockprovider.ShardStoreLockPathPrefix, trie.WORD_ANY})
	shardNode.Value = lockprovider.NewShardStoreLock()

	return &Service{
		lock:         &sync.Mutex{},
		provdersTrie: providers,
	}
}
// acquireInfo is a lock request queued in Service.acquirings while it waits
// for its locks to become available.
type acquireInfo struct {
	// Request is the lock request to retry.
	Request types.LockRequest
	// Callback is completed with the new RequestID once the request is granted.
	Callback *future.SetValueFuture[types.RequestID]
	// LastErr is the most recent reason the request could not be granted.
	LastErr error
}
// Acquire obtains every lock in req and returns a *Mutex holding them.
//
// The request is first tested in place. If it cannot be granted right now it
// is queued, and Acquire blocks until a later release lets it through or the
// timeout expires. The default timeout is 10 seconds; WithTimeout(0) disables
// the deadline so Acquire waits indefinitely.
//
// A request that names a path with no registered lock provider fails
// immediately, since no release could ever make it grantable.
func (svc *Service) Acquire(req types.LockRequest, opts ...AcquireOptionFn) (*Mutex, error) {
	opt := AcquireOption{
		Timeout: time.Second * 10,
	}
	for _, fn := range opts {
		fn(&opt)
	}

	ctx := context.Background()
	if opt.Timeout != 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, opt.Timeout)
		defer cancel()
	}

	newMutex := func(reqID types.RequestID) *Mutex {
		return &Mutex{
			svc:       svc,
			lockReq:   req,
			lockReqID: reqID,
		}
	}

	svc.lock.Lock()
	defer svc.lock.Unlock()

	// Fail fast on paths that no provider handles: waiting could never help.
	for _, l := range req.Locks {
		n, ok := svc.provdersTrie.WalkEnd(l.Path)
		if !ok || n.Value == nil {
			return nil, fmt.Errorf("lock provider not found for path %v", l.Path)
		}
	}

	// Try to grant the request in place. tryAcquireOne only reports an error
	// when the locks cannot be taken right now, so an error here means
	// "wait", not "fail".
	reqID, err := svc.tryAcquireOne(req)
	if err == nil {
		return newMutex(reqID), nil
	}

	// Queue the request; tryAcquirings completes info.Callback once a release
	// makes it grantable.
	info := &acquireInfo{
		Request:  req,
		Callback: future.NewSetValue[types.RequestID](),
		LastErr:  err,
	}
	svc.acquirings = append(svc.acquirings, info)

	// Release the service lock while waiting, otherwise no release could run.
	svc.lock.Unlock()
	reqID, err = info.Callback.Wait(ctx)
	svc.lock.Lock()

	if err == nil {
		return newMutex(reqID), nil
	}

	if err != future.ErrCanceled {
		// lo2.Remove returns the shortened slice; it must be stored back or
		// the stale entry stays queued and may later be granted locks that
		// nobody will ever release.
		svc.acquirings = lo2.Remove(svc.acquirings, info)
		return nil, err
	}

	// The wait timed out, but tryAcquirings may have granted the request
	// between the timeout and re-taking svc.lock. Check again under the lock
	// so a granted lock is not leaked.
	reqID, err = info.Callback.TryGetValue()
	if err == nil {
		return newMutex(reqID), nil
	}

	svc.acquirings = lo2.Remove(svc.acquirings, info)
	if info.LastErr != nil {
		return nil, fmt.Errorf("%w (last lock error: %v)", err, info.LastErr)
	}
	return nil, err
}
// BeginReentrant returns a new Reentrant bound to this Service.
func (svc *Service) BeginReentrant() *Reentrant {
	r := &Reentrant{svc: svc}
	return r
}
// release frees the locks that reqID holds for req and then lets queued
// requests retry against the updated lock state. It takes svc.lock itself.
func (svc *Service) release(reqID types.RequestID, req types.LockRequest) {
	svc.lock.Lock()
	defer svc.lock.Unlock()

	svc.releaseRequest(reqID, req)
	svc.tryAcquirings()
}
// tryAcquirings retries every queued request, in queue order, against the
// current lock state. A request that can now be granted is locked, its waiter
// receives the new RequestID, and it is dropped from the queue. The rest stay
// queued with LastErr updated to the latest refusal.
//
// The caller in this file (release) holds svc.lock while calling it.
func (svc *Service) tryAcquirings() {
	for idx, pending := range svc.acquirings {
		granted, err := svc.tryAcquireOne(pending.Request)
		if err != nil {
			pending.LastErr = err
			continue
		}

		pending.Callback.SetValue(granted)
		// Mark the slot empty; all empty slots are compacted away below.
		svc.acquirings[idx] = nil
	}
	svc.acquirings = lo2.RemoveAllDefault(svc.acquirings)
}
// tryAcquireOne grants req immediately if every lock in it is currently
// available. On success it allocates a fresh RequestID from nextReqID and
// records the locks under that ID. Otherwise it returns testOneRequest's
// error and changes nothing.
func (svc *Service) tryAcquireOne(req types.LockRequest) (types.RequestID, error) {
	if err := svc.testOneRequest(req); err != nil {
		return "", err
	}

	id := types.RequestID(fmt.Sprintf("%d", svc.nextReqID))
	svc.nextReqID++

	svc.applyRequest(id, req)
	return id, nil
}
// testOneRequest reports whether every lock in req could be taken now without
// modifying any state. It returns an error if a lock's path has no provider,
// or the first CanLock error from a provider.
func (svc *Service) testOneRequest(req types.LockRequest) error {
	for _, l := range req.Locks {
		node, found := svc.provdersTrie.WalkEnd(l.Path)
		if !found || node.Value == nil {
			return fmt.Errorf("lock provider not found for path %v", l.Path)
		}

		if err := node.Value.CanLock(l); err != nil {
			return err
		}
	}
	return nil
}
// applyRequest records reqID as the holder of every lock in req with the
// matching provider. Its caller (tryAcquireOne) has already validated req with
// testOneRequest, so each path resolves to a provider and WalkEnd's ok flag is
// not rechecked.
func (svc *Service) applyRequest(reqID types.RequestID, req types.LockRequest) {
	for i := range req.Locks {
		node, _ := svc.provdersTrie.WalkEnd(req.Locks[i].Path)
		node.Value.Lock(reqID, req.Locks[i])
	}
}
// releaseRequest asks each lock's provider to drop reqID's hold on it. The
// request is assumed to have been granted earlier, so every path resolves to
// a provider and WalkEnd's ok flag is not rechecked.
func (svc *Service) releaseRequest(reqID types.RequestID, req types.LockRequest) {
	for i := range req.Locks {
		node, _ := svc.provdersTrie.WalkEnd(req.Locks[i].Path)
		node.Value.Unlock(reqID, req.Locks[i])
	}
}