// Page header from the original listing: 195 lines, 4.0 KiB, Go.

package distlock
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"gitlink.org.cn/cloudream/common/pkgs/future"
|
|
"gitlink.org.cn/cloudream/common/pkgs/trie"
|
|
"gitlink.org.cn/cloudream/common/utils/lo2"
|
|
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/lockprovider"
|
|
"gitlink.org.cn/cloudream/jcs-pub/common/pkgs/distlock/types"
|
|
)
|
|
|
|
type (
	// AcquireOption holds the tunables for a single Service.Acquire call.
	AcquireOption struct {
		// Timeout bounds how long Acquire waits for the requested locks.
		// Acquire only installs a deadline when this value is non-zero.
		Timeout time.Duration
	}

	// AcquireOptionFn mutates an AcquireOption; Acquire applies each one in
	// the order given.
	AcquireOptionFn func(opt *AcquireOption)
)

// WithTimeout returns an option that sets AcquireOption.Timeout to timeout.
func WithTimeout(timeout time.Duration) AcquireOptionFn {
	return func(o *AcquireOption) { o.Timeout = timeout }
}
|
|
|
|
type Service struct {
|
|
lock *sync.Mutex
|
|
provdersTrie *trie.Trie[types.LockProvider]
|
|
acquirings []*acquireInfo
|
|
nextReqID int64
|
|
}
|
|
|
|
func NewService() *Service {
|
|
svc := &Service{
|
|
lock: &sync.Mutex{},
|
|
provdersTrie: trie.NewTrie[types.LockProvider](),
|
|
}
|
|
|
|
svc.provdersTrie.Create([]any{lockprovider.ShardStoreLockPathPrefix, trie.WORD_ANY}).Value = lockprovider.NewShardStoreLock()
|
|
return svc
|
|
}
|
|
|
|
type acquireInfo struct {
|
|
Request types.LockRequest
|
|
Callback *future.SetValueFuture[types.RequestID]
|
|
LastErr error
|
|
}
|
|
|
|
func (svc *Service) Acquire(req types.LockRequest, opts ...AcquireOptionFn) (*Mutex, error) {
|
|
var opt = AcquireOption{
|
|
Timeout: time.Second * 10,
|
|
}
|
|
for _, fn := range opts {
|
|
fn(&opt)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
if opt.Timeout != 0 {
|
|
var cancel func()
|
|
ctx, cancel = context.WithTimeout(ctx, opt.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// 就地检测锁是否可用
|
|
svc.lock.Lock()
|
|
defer svc.lock.Unlock()
|
|
|
|
reqID, err := svc.tryAcquireOne(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if reqID != "" {
|
|
return &Mutex{
|
|
svc: svc,
|
|
lockReq: req,
|
|
lockReqID: reqID,
|
|
}, nil
|
|
}
|
|
|
|
// 就地检测失败,那么就需要异步等待锁可用
|
|
info := &acquireInfo{
|
|
Request: req,
|
|
Callback: future.NewSetValue[types.RequestID](),
|
|
}
|
|
svc.acquirings = append(svc.acquirings, info)
|
|
|
|
// 等待的时候不加锁
|
|
svc.lock.Unlock()
|
|
reqID, err = info.Callback.Wait(ctx)
|
|
svc.lock.Lock()
|
|
|
|
if err == nil {
|
|
return &Mutex{
|
|
svc: svc,
|
|
lockReq: req,
|
|
lockReqID: reqID,
|
|
}, nil
|
|
}
|
|
|
|
if err != future.ErrCanceled {
|
|
lo2.Remove(svc.acquirings, info)
|
|
return nil, err
|
|
}
|
|
|
|
// 如果第一次等待是超时错误,那么在锁里再尝试获取一次结果
|
|
reqID, err = info.Callback.TryGetValue()
|
|
if err == nil {
|
|
return &Mutex{
|
|
svc: svc,
|
|
lockReq: req,
|
|
lockReqID: reqID,
|
|
}, nil
|
|
}
|
|
|
|
lo2.Remove(svc.acquirings, info)
|
|
return nil, err
|
|
}
|
|
|
|
func (s *Service) BeginReentrant() *Reentrant {
|
|
return &Reentrant{
|
|
svc: s,
|
|
}
|
|
}
|
|
|
|
func (s *Service) release(reqID types.RequestID, req types.LockRequest) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
s.releaseRequest(reqID, req)
|
|
s.tryAcquirings()
|
|
}
|
|
|
|
func (a *Service) tryAcquirings() {
|
|
for i := 0; i < len(a.acquirings); i++ {
|
|
req := a.acquirings[i]
|
|
|
|
reqID, err := a.tryAcquireOne(req.Request)
|
|
if err != nil {
|
|
req.LastErr = err
|
|
continue
|
|
}
|
|
|
|
req.Callback.SetValue(reqID)
|
|
a.acquirings[i] = nil
|
|
}
|
|
|
|
a.acquirings = lo2.RemoveAllDefault(a.acquirings)
|
|
}
|
|
|
|
func (s *Service) tryAcquireOne(req types.LockRequest) (types.RequestID, error) {
|
|
err := s.testOneRequest(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
reqID := types.RequestID(fmt.Sprintf("%d", s.nextReqID))
|
|
s.nextReqID++
|
|
|
|
s.applyRequest(reqID, req)
|
|
return reqID, nil
|
|
}
|
|
|
|
func (s *Service) testOneRequest(req types.LockRequest) error {
|
|
for _, lock := range req.Locks {
|
|
n, ok := s.provdersTrie.WalkEnd(lock.Path)
|
|
if !ok || n.Value == nil {
|
|
return fmt.Errorf("lock provider not found for path %v", lock.Path)
|
|
}
|
|
|
|
err := n.Value.CanLock(lock)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) applyRequest(reqID types.RequestID, req types.LockRequest) {
|
|
for _, lock := range req.Locks {
|
|
p, _ := s.provdersTrie.WalkEnd(lock.Path)
|
|
p.Value.Lock(reqID, lock)
|
|
}
|
|
}
|
|
|
|
func (s *Service) releaseRequest(reqID types.RequestID, req types.LockRequest) {
|
|
for _, lock := range req.Locks {
|
|
p, _ := s.provdersTrie.WalkEnd(lock.Path)
|
|
p.Value.Unlock(reqID, lock)
|
|
}
|
|
}
|