// Forked from JointCloud/pcm-coordinator (357 lines, 9.8 KiB, Go).

package hpc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
jsoniter "github.com/json-iterator/go"
|
|
"github.com/pkg/errors"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"text/template"
|
|
"time"
|
|
)
|
|
|
|
type CommitHpcTaskLogic struct {
|
|
logx.Logger
|
|
ctx context.Context
|
|
svcCtx *svc.ServiceContext
|
|
hpcService *service.HpcService
|
|
}
|
|
|
|
// String literals written to the task tables by SaveHpcTaskToDB.
const (
	// statusSaved is the Status stored on the task row.
	statusSaved = "Saved"
	// statusDeploying is the Status stored on the task_hpc row.
	statusDeploying = "Deploying"
	// adapterTypeHPC is the AdapterTypeDict value stored on the task row.
	adapterTypeHPC = "2"
)
|
|
|
|
type JobRequest struct {
|
|
App string `json:"app"`
|
|
Common CommonParams `json:"common"`
|
|
AppSpecific map[string]interface{} `json:"appSpecific"`
|
|
}
|
|
// CommonParams holds the job parameters that ConvertToJobRequest copies out
// of the request's parameter map. All values are kept as strings; Time is
// omitted from JSON output when empty.
type CommonParams struct {
	JobName   string `json:"jobName"`
	Partition string `json:"partition"`
	Nodes     string `json:"nodes"`
	NTasks    string `json:"ntasks"`
	Time      string `json:"time,omitempty"`
	App       string `json:"app"`
}
|
|
|
|
func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
|
|
cache := make(map[string]interface{}, 10)
|
|
hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return &CommitHpcTaskLogic{
|
|
Logger: logx.WithContext(ctx),
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
hpcService: hpcService,
|
|
}
|
|
}
|
|
|
|
// templateCache holds parsed job-script templates keyed by their source
// text, so repeated renders of the same template can reuse the parsed form.
var templateCache sync.Map
|
|
|
|
func (l *CommitHpcTaskLogic) getClusterInfo(clusterID string) (*types.ClusterInfo, *types.AdapterInfo, error) {
|
|
var clusterInfo types.ClusterInfo
|
|
if err := l.svcCtx.DbEngin.Table("t_cluster").Where("id = ?", clusterID).First(&clusterInfo).Error; err != nil {
|
|
return nil, nil, fmt.Errorf("cluster query failed: %w", err)
|
|
}
|
|
|
|
if clusterInfo.Id == "" {
|
|
return nil, nil, errors.New("cluster not found")
|
|
}
|
|
|
|
var adapterInfo types.AdapterInfo
|
|
if err := l.svcCtx.DbEngin.Table("t_adapter").Where("id = ?", clusterInfo.AdapterId).First(&adapterInfo).Error; err != nil {
|
|
return nil, nil, fmt.Errorf("adapter query failed: %w", err)
|
|
}
|
|
|
|
if adapterInfo.Id == "" {
|
|
return nil, nil, errors.New("adapter not found")
|
|
}
|
|
|
|
return &clusterInfo, &adapterInfo, nil
|
|
}
|
|
|
|
// 自定义函数映射
|
|
func createFuncMap() template.FuncMap {
|
|
return template.FuncMap{
|
|
"regexMatch": regexMatch,
|
|
"required": required,
|
|
"error": errorHandler,
|
|
"default": defaultHandler,
|
|
}
|
|
}
|
|
// userErrorPattern matches the "error calling <func>: <message>" tail of a
// template execution error and captures the message. It is compiled once at
// package initialisation instead of on every call.
var userErrorPattern = regexp.MustCompile(`error calling \w+: (.*)$`)

// extractUserError strips the template-engine prefix from an execution error.
//
// When originalErr's text contains "error calling <func>: <message>", a new
// error carrying only <message> is returned. Any other error is returned
// unchanged, and a nil input yields nil.
func extractUserError(originalErr error) error {
	if originalErr == nil {
		return nil
	}
	if m := userErrorPattern.FindStringSubmatch(originalErr.Error()); len(m) > 1 {
		return errors.New(m[1])
	}
	return originalErr
}
|
|
|
|
// regexMatch compiles pattern and returns the resulting regular expression.
// It is registered as the "regexMatch" template function. An invalid pattern
// makes regexp.MustCompile panic.
func regexMatch(pattern string) *regexp.Regexp {
	re := regexp.MustCompile(pattern)
	return re
}
|
|
|
|
// required is the mandatory-field check registered as the "required"
// template function. It returns val unchanged unless val is nil or the empty
// string, in which case it returns an error whose text is msg.
func required(msg string, val interface{}) (interface{}, error) {
	switch v := val.(type) {
	case nil:
		return nil, errors.New(msg)
	case string:
		if v == "" {
			return nil, errors.New(msg)
		}
	}
	return val, nil
}
|
|
|
|
// errorHandler is registered as the "error" template function. It always
// fails: the returned string is empty and the error text is msg.
func errorHandler(msg string) (string, error) {
	var empty string
	return empty, errors.New(msg)
}
|
|
|
|
// defaultHandler is registered as the "default" template function. It
// returns defaultVal when val is nil, the empty string, or the int 0, and
// returns val unchanged for every other value or type.
func defaultHandler(defaultVal interface{}, val interface{}) interface{} {
	if val == nil {
		return defaultVal
	}
	if s, isString := val.(string); isString && s == "" {
		return defaultVal
	}
	if n, isInt := val.(int); isInt && n == 0 {
		return defaultVal
	}
	// Further types can be given an "empty" check here if needed.
	return val
}
|
|
func (l *CommitHpcTaskLogic) RenderJobScript(templateContent string, req *JobRequest) (string, error) {
|
|
// 使用缓存模板
|
|
tmpl, ok := templateCache.Load(templateContent)
|
|
if !ok {
|
|
parsedTmpl, err := template.New("slurmTemplate").Funcs(createFuncMap()).Parse(templateContent)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
templateCache.Store(templateContent, parsedTmpl)
|
|
tmpl = parsedTmpl
|
|
}
|
|
|
|
params := map[string]interface{}{
|
|
"Common": req.Common,
|
|
"App": req.AppSpecific,
|
|
}
|
|
|
|
var buf strings.Builder
|
|
if err := tmpl.(*template.Template).Execute(&buf, params); err != nil {
|
|
log.Error().Err(err).Msg("模板渲染失败")
|
|
return "", extractUserError(err)
|
|
}
|
|
return buf.String(), nil
|
|
}
|
|
|
|
func ConvertToJobRequest(job *types.CommitHpcTaskReq) (JobRequest, error) {
|
|
required := []string{"jobName", "nodes", "ntasks"}
|
|
for _, field := range required {
|
|
if job.Parameters[field] == "" {
|
|
return JobRequest{}, fmt.Errorf("%s is empty", field)
|
|
}
|
|
}
|
|
|
|
return JobRequest{
|
|
App: job.App,
|
|
Common: CommonParams{
|
|
JobName: job.Parameters["jobName"],
|
|
Partition: job.Parameters["partition"],
|
|
Nodes: job.Parameters["nodes"],
|
|
NTasks: job.Parameters["ntasks"],
|
|
Time: job.Parameters["time"],
|
|
App: job.App,
|
|
},
|
|
AppSpecific: utils.MpaStringToInterface(job.Parameters),
|
|
}, nil
|
|
}
|
|
|
|
func (l *CommitHpcTaskLogic) SaveHpcTaskToDB(req *types.CommitHpcTaskReq, jobScript, jobId, workDir string) (taskId string, err error) {
|
|
// 使用事务确保数据一致性
|
|
tx := l.svcCtx.DbEngin.Begin()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
tx.Rollback()
|
|
err = fmt.Errorf("transaction panic: %v", r)
|
|
} else if err != nil {
|
|
tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
userID, _ := strconv.ParseInt(req.Parameters["UserId"], 10, 64)
|
|
taskID := utils.GenSnowflakeID()
|
|
taskModel := models.Task{
|
|
Id: taskID,
|
|
Name: req.Name,
|
|
Description: req.Description,
|
|
CommitTime: time.Now(),
|
|
Status: statusSaved,
|
|
AdapterTypeDict: adapterTypeHPC,
|
|
UserId: userID,
|
|
}
|
|
|
|
if err = tx.Table("task").Create(&taskModel).Error; err != nil {
|
|
return "", fmt.Errorf("failed to create task: %w", err)
|
|
}
|
|
|
|
clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
paramsJSON, err := jsoniter.MarshalToString(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal parameters: %w", err)
|
|
}
|
|
|
|
clusterID := utils.StringToInt64(clusterInfo.Id)
|
|
hpcTask := models.TaskHpc{
|
|
Id: utils.GenSnowflakeID(),
|
|
TaskId: taskID,
|
|
AdapterId: clusterInfo.AdapterId,
|
|
AdapterName: adapterInfo.Name,
|
|
ClusterId: clusterID,
|
|
ClusterName: clusterInfo.Name,
|
|
Name: taskModel.Name,
|
|
Backend: req.Backend,
|
|
OperateType: req.OperateType,
|
|
CmdScript: req.Parameters["cmdScript"],
|
|
WallTime: req.Parameters["wallTime"],
|
|
AppType: req.Parameters["appType"],
|
|
AppName: req.App,
|
|
Queue: req.Parameters["queue"],
|
|
SubmitType: req.Parameters["submitType"],
|
|
NNode: req.Parameters["nNode"],
|
|
Account: clusterInfo.Username,
|
|
StdInput: req.Parameters["stdInput"],
|
|
Partition: req.Parameters["partition"],
|
|
CreatedTime: time.Now(),
|
|
UpdatedTime: time.Now(),
|
|
Status: statusDeploying,
|
|
UserId: userID,
|
|
Params: paramsJSON,
|
|
Script: jobScript,
|
|
JobId: jobId,
|
|
WorkDir: workDir,
|
|
}
|
|
|
|
if err = tx.Table("task_hpc").Create(&hpcTask).Error; err != nil {
|
|
return "", fmt.Errorf("failed to create HPC task: %w", err)
|
|
}
|
|
|
|
noticeInfo := clientCore.NoticeInfo{
|
|
AdapterId: clusterInfo.AdapterId,
|
|
AdapterName: adapterInfo.Name,
|
|
ClusterId: clusterID,
|
|
ClusterName: clusterInfo.Name,
|
|
NoticeType: "create",
|
|
TaskName: req.Name,
|
|
TaskId: taskID,
|
|
Incident: "任务创建中",
|
|
CreatedTime: time.Now(),
|
|
}
|
|
|
|
if err = tx.Table("t_notice").Create(¬iceInfo).Error; err != nil {
|
|
return "", fmt.Errorf("failed to create notice: %w", err)
|
|
}
|
|
|
|
if err = tx.Commit().Error; err != nil {
|
|
return "", fmt.Errorf("transaction commit failed: %w", err)
|
|
}
|
|
|
|
return utils.Int64ToString(taskID), nil
|
|
}
|
|
|
|
func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
|
|
jobName := generateJobName(req)
|
|
req.Parameters["jobName"] = jobName
|
|
|
|
// 获取集群和适配器信息
|
|
clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
scriptContent := req.ScriptContent
|
|
if scriptContent == "" {
|
|
// 获取模板
|
|
var templateInfo types.HpcAppTemplateInfo
|
|
tx := l.svcCtx.DbEngin.Table("hpc_app_template").
|
|
Where("cluster_id = ? and app = ? ", req.ClusterId, req.App)
|
|
if req.OperateType != "" {
|
|
tx.Where("app_type = ?", req.OperateType)
|
|
}
|
|
if err := tx.First(&templateInfo).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get template: %w", err)
|
|
}
|
|
|
|
// 转换请求参数
|
|
jobRequest, err := ConvertToJobRequest(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 渲染脚本
|
|
script, err := l.RenderJobScript(templateInfo.Content, &jobRequest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
scriptContent = script
|
|
}
|
|
|
|
q, _ := jsoniter.MarshalToString(scriptContent)
|
|
submitQ := types.SubmitHpcTaskReq{
|
|
App: req.App,
|
|
ClusterId: req.ClusterId,
|
|
JobName: jobName,
|
|
ScriptContent: scriptContent,
|
|
Parameters: req.Parameters,
|
|
Backend: req.Backend,
|
|
}
|
|
log.Info().Msgf("Submitting HPC task to cluster %s with params: %s", clusterInfo.Name, q)
|
|
resp, err = l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id].SubmitTask(l.ctx, submitQ)
|
|
if err != nil {
|
|
log.Error().Err(err).Msgf("提交HPC任务失败, cluster: %s, jobName: %s, scriptContent: %s", clusterInfo.Name, jobName, scriptContent)
|
|
return nil, fmt.Errorf("网络请求失败,请稍后重试")
|
|
}
|
|
|
|
jobID := resp.Data.JobInfo["jobId"]
|
|
workDir := resp.Data.JobInfo["jobDir"]
|
|
taskID, err := l.SaveHpcTaskToDB(req, scriptContent, jobID, workDir)
|
|
if err != nil {
|
|
log.Error().Msgf("Failed to save task to DB: %v", err)
|
|
return nil, fmt.Errorf("db save failed: %w", err)
|
|
}
|
|
|
|
resp.Data.JobInfo["taskId"] = taskID
|
|
return resp, nil
|
|
}
|
|
|
|
func generateJobName(req *types.CommitHpcTaskReq) string {
|
|
if req.OperateType == "" {
|
|
return req.Name
|
|
}
|
|
return req.Name + "_" + req.OperateType
|
|
}
|