139 lines
2.9 KiB
Go
139 lines
2.9 KiB
Go
package executormgr
|
||
|
||
import (
|
||
"fmt"
|
||
"sync"
|
||
"time"
|
||
|
||
"gitlink.org.cn/cloudream/common/utils/sync2"
|
||
|
||
schglb "gitlink.org.cn/cloudream/scheduler/common/globals"
|
||
schmod "gitlink.org.cn/cloudream/scheduler/common/models"
|
||
exemq "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor"
|
||
exetsk "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/executor/task"
|
||
mgrmq "gitlink.org.cn/cloudream/scheduler/common/pkgs/mq/manager"
|
||
)
|
||
|
||
type task struct {
|
||
statusChan *sync2.Channel[exetsk.TaskStatus]
|
||
}
|
||
type ExecutorStatus struct {
|
||
executorID schmod.ExecutorID
|
||
tasks map[string]task // key 为 TaskID
|
||
lastReportTime time.Time
|
||
}
|
||
|
||
var (
	// ErrWaitReportTimeout is the error Serve uses to close the status channel
	// of every task whose executor has not reported within the Manager's
	// reportTimeout.
	ErrWaitReportTimeout = fmt.Errorf("wait report timeout")
)
|
||
|
||
type Manager struct {
|
||
executors map[schmod.ExecutorID]*ExecutorStatus
|
||
lock sync.Mutex
|
||
exeCli *exemq.Client
|
||
|
||
reportTimeout time.Duration
|
||
}
|
||
|
||
func NewManager(reportTimeout time.Duration) (*Manager, error) {
|
||
exeCli, err := schglb.ExecutorMQPool.Acquire()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new executor client: %w", err)
|
||
}
|
||
|
||
return &Manager{
|
||
executors: make(map[schmod.ExecutorID]*ExecutorStatus),
|
||
exeCli: exeCli,
|
||
reportTimeout: reportTimeout,
|
||
}, nil
|
||
}
|
||
|
||
func (m *Manager) Report(execID schmod.ExecutorID, taskStatus []mgrmq.ExecutorTaskStatus) {
|
||
m.lock.Lock()
|
||
defer m.lock.Unlock()
|
||
|
||
exec, ok := m.executors[execID]
|
||
if !ok {
|
||
exec = &ExecutorStatus{
|
||
executorID: execID,
|
||
tasks: make(map[string]task),
|
||
}
|
||
m.executors[execID] = exec
|
||
}
|
||
|
||
exec.lastReportTime = time.Now()
|
||
|
||
for _, s := range taskStatus {
|
||
tsk, ok := exec.tasks[s.TaskID]
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// TODO 考虑主动检测channel是否关闭,然后取消task
|
||
if tsk.statusChan.Send(s.Status) != nil {
|
||
delete(exec.tasks, s.TaskID)
|
||
|
||
if len(exec.tasks) == 0 {
|
||
delete(m.executors, execID)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 启动一个Task
|
||
func (m *Manager) StartTask(info exetsk.TaskInfo) *sync2.Channel[exetsk.TaskStatus] {
|
||
m.lock.Lock()
|
||
defer m.lock.Unlock()
|
||
ch := sync2.NewChannel[exetsk.TaskStatus]()
|
||
|
||
resp, err := m.exeCli.StartTask(exemq.NewStartTask(info))
|
||
if err != nil {
|
||
ch.CloseWithError(fmt.Errorf("start task: %w", err))
|
||
return ch
|
||
}
|
||
|
||
exeInfo, ok := m.executors[resp.ExecutorID]
|
||
if !ok {
|
||
exeInfo = &ExecutorStatus{
|
||
executorID: resp.ExecutorID,
|
||
tasks: make(map[string]task),
|
||
lastReportTime: time.Now(),
|
||
}
|
||
m.executors[resp.ExecutorID] = exeInfo
|
||
}
|
||
|
||
exeInfo.tasks[resp.TaskID] = task{
|
||
statusChan: ch,
|
||
}
|
||
|
||
return ch
|
||
}
|
||
|
||
func (m *Manager) Serve() error {
|
||
ticker := time.NewTicker(time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
func() {
|
||
m.lock.Lock()
|
||
defer m.lock.Unlock()
|
||
|
||
now := time.Now()
|
||
for exeID, exeInfo := range m.executors {
|
||
dt := now.Sub(exeInfo.lastReportTime)
|
||
|
||
if dt < m.reportTimeout {
|
||
continue
|
||
}
|
||
|
||
for _, tsk := range exeInfo.tasks {
|
||
tsk.statusChan.CloseWithError(ErrWaitReportTimeout)
|
||
}
|
||
|
||
delete(m.executors, exeID)
|
||
}
|
||
}()
|
||
}
|
||
}
|
||
}
|