140 lines
2.9 KiB
Go
140 lines
2.9 KiB
Go
package advisormgr
|
||
|
||
import (
|
||
"fmt"
|
||
"sync"
|
||
"time"
|
||
|
||
"gitlink.org.cn/cloudream/common/utils/sync2"
|
||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||
advmq "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/advisor"
|
||
advtsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/advisor/task"
|
||
mgrmq "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/manager"
|
||
)
|
||
|
||
type task struct {
|
||
statusChan *sync2.Channel[advtsk.TaskStatus]
|
||
}
|
||
|
||
type AdvisorInfo struct {
|
||
advisorID schmod.AdvisorID
|
||
tasks map[string]task // key 为 TaskID
|
||
lastReportTime time.Time
|
||
}
|
||
|
||
// ErrWaitReportTimeout is the error a task's status channel is closed with
// when its advisor has not reported within the manager's report timeout.
var ErrWaitReportTimeout = fmt.Errorf("wait report timeout")
|
||
|
||
type Manager struct {
|
||
advisors map[schmod.AdvisorID]*AdvisorInfo
|
||
lock sync.Mutex
|
||
advCli *advmq.Client
|
||
|
||
reportTimeout time.Duration
|
||
}
|
||
|
||
func NewManager(reportTimeout time.Duration) (*Manager, error) {
|
||
advCli, err := schglb.AdvisorMQPool.Acquire()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new executor client: %w", err)
|
||
}
|
||
|
||
return &Manager{
|
||
advisors: make(map[schmod.AdvisorID]*AdvisorInfo),
|
||
advCli: advCli,
|
||
reportTimeout: reportTimeout,
|
||
}, nil
|
||
}
|
||
|
||
func (m *Manager) Report(advID schmod.AdvisorID, taskStatus []mgrmq.AdvisorTaskStatus) {
|
||
m.lock.Lock()
|
||
defer m.lock.Unlock()
|
||
|
||
adv, ok := m.advisors[advID]
|
||
if !ok {
|
||
adv = &AdvisorInfo{
|
||
advisorID: advID,
|
||
tasks: make(map[string]task),
|
||
}
|
||
m.advisors[advID] = adv
|
||
}
|
||
|
||
adv.lastReportTime = time.Now()
|
||
|
||
for _, s := range taskStatus {
|
||
tsk, ok := adv.tasks[s.TaskID]
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// TODO 考虑主动检测channel是否关闭,然后取消task
|
||
if tsk.statusChan.Send(s.Status) != nil {
|
||
delete(adv.tasks, s.TaskID)
|
||
|
||
if len(adv.tasks) == 0 {
|
||
delete(m.advisors, advID)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 启动一个Task
|
||
func (m *Manager) StartTask(info advtsk.TaskInfo) *sync2.Channel[advtsk.TaskStatus] {
|
||
m.lock.Lock()
|
||
defer m.lock.Unlock()
|
||
|
||
ch := sync2.NewChannel[advtsk.TaskStatus]()
|
||
|
||
resp, err := m.advCli.StartTask(advmq.NewStartTask(info))
|
||
if err != nil {
|
||
ch.CloseWithError(fmt.Errorf("start task: %w", err))
|
||
return ch
|
||
}
|
||
|
||
exeInfo, ok := m.advisors[resp.AdvisorID]
|
||
if !ok {
|
||
exeInfo = &AdvisorInfo{
|
||
advisorID: resp.AdvisorID,
|
||
tasks: make(map[string]task),
|
||
lastReportTime: time.Now(),
|
||
}
|
||
m.advisors[resp.AdvisorID] = exeInfo
|
||
}
|
||
|
||
exeInfo.tasks[resp.TaskID] = task{
|
||
statusChan: ch,
|
||
}
|
||
|
||
return ch
|
||
}
|
||
|
||
func (m *Manager) Serve() error {
|
||
ticker := time.NewTicker(time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
func() {
|
||
m.lock.Lock()
|
||
defer m.lock.Unlock()
|
||
|
||
now := time.Now()
|
||
for exeID, exeInfo := range m.advisors {
|
||
dt := now.Sub(exeInfo.lastReportTime)
|
||
|
||
if dt < m.reportTimeout {
|
||
continue
|
||
}
|
||
|
||
for _, tsk := range exeInfo.tasks {
|
||
tsk.statusChan.CloseWithError(ErrWaitReportTimeout)
|
||
}
|
||
|
||
delete(m.advisors, exeID)
|
||
}
|
||
}()
|
||
}
|
||
}
|
||
}
|