pcm-coordinator/internal/logic/hpc/commithpctasklogic.go

357 lines
9.8 KiB
Go

package hpc
import (
"context"
"fmt"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/zeromicro/go-zero/core/logx"
clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
"regexp"
"strconv"
"strings"
"sync"
"text/template"
"time"
)
// CommitHpcTaskLogic carries the per-request state used by the methods in this
// file to render a job script, submit it to an HPC executor and persist the
// resulting task.
type CommitHpcTaskLogic struct {
// Logger is the embedded go-zero logger, created from the request context.
logx.Logger
// ctx is the context given to NewCommitHpcTaskLogic; CommitHpcTask forwards it to SubmitTask.
ctx context.Context
// svcCtx exposes the database handle (DbEngin) used for all queries here.
svcCtx *svc.ServiceContext
// hpcService provides HpcExecutorAdapterMap, which CommitHpcTask indexes by adapter ID.
hpcService *service.HpcService
}
// Fixed values written to the database by SaveHpcTaskToDB.
const (
// statusSaved is stored as the Status of the generic task row.
statusSaved = "Saved"
// statusDeploying is stored as the Status of the HPC task row.
statusDeploying = "Deploying"
// adapterTypeHPC is stored as the AdapterTypeDict of the generic task row.
// NOTE(review): presumably "2" is the dictionary code for HPC adapters — confirm.
adapterTypeHPC = "2"
)
// JobRequest is the input handed to RenderJobScript. ConvertToJobRequest
// builds it from a CommitHpcTaskReq.
type JobRequest struct {
// App is copied from the request's App field.
App string `json:"app"`
// Common is exposed to the template as ".Common".
Common CommonParams `json:"common"`
// AppSpecific is exposed to the template as ".App"; ConvertToJobRequest
// fills it from the request's Parameters map.
AppSpecific map[string]interface{} `json:"appSpecific"`
}
// CommonParams holds the job settings that ConvertToJobRequest extracts from
// the request's Parameters map by fixed key. All values are kept as strings.
type CommonParams struct {
// JobName comes from Parameters["jobName"]; ConvertToJobRequest rejects an empty value.
JobName string `json:"jobName"`
// Partition comes from Parameters["partition"]; it may be empty.
Partition string `json:"partition"`
// Nodes comes from Parameters["nodes"]; ConvertToJobRequest rejects an empty value.
Nodes string `json:"nodes"`
// NTasks comes from Parameters["ntasks"]; ConvertToJobRequest rejects an empty value.
NTasks string `json:"ntasks"`
// Time comes from Parameters["time"]; it is omitted from JSON when empty.
Time string `json:"time,omitempty"`
// App is copied from the request's App field.
App string `json:"app"`
}
// NewCommitHpcTaskLogic builds a CommitHpcTaskLogic bound to ctx and svcCtx.
//
// It creates the HPC service from the service configuration and the
// scheduler's HPC storages, handing it a fresh, empty cache map. The signature
// has no error result, so when the HPC service cannot be created the error is
// logged and nil is returned; callers must check for nil before use.
func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
	logger := logx.WithContext(ctx)

	cache := make(map[string]interface{}, 10)
	hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
	if err != nil {
		// Previously this error was dropped, leaving no trace of why nil was returned.
		logger.Errorf("failed to create HPC service: %v", err)
		return nil
	}

	return &CommitHpcTaskLogic{
		Logger:     logger,
		ctx:        ctx,
		svcCtx:     svcCtx,
		hpcService: hpcService,
	}
}
// templateCache stores parsed job-script templates keyed by their raw template
// text, so RenderJobScript can reuse a parsed template instead of parsing the
// same text again.
var templateCache = sync.Map{}
// getClusterInfo looks up the cluster row with the given ID in t_cluster and
// then the adapter row referenced by that cluster's AdapterId in t_adapter.
//
// It returns both records, or an error when either query fails or yields a
// record whose Id is empty.
func (l *CommitHpcTaskLogic) getClusterInfo(clusterID string) (*types.ClusterInfo, *types.AdapterInfo, error) {
	cluster := new(types.ClusterInfo)
	clusterQuery := l.svcCtx.DbEngin.Table("t_cluster").Where("id = ?", clusterID).First(cluster)
	switch {
	case clusterQuery.Error != nil:
		return nil, nil, fmt.Errorf("cluster query failed: %w", clusterQuery.Error)
	case cluster.Id == "":
		return nil, nil, errors.New("cluster not found")
	}

	adapter := new(types.AdapterInfo)
	adapterQuery := l.svcCtx.DbEngin.Table("t_adapter").Where("id = ?", cluster.AdapterId).First(adapter)
	switch {
	case adapterQuery.Error != nil:
		return nil, nil, fmt.Errorf("adapter query failed: %w", adapterQuery.Error)
	case adapter.Id == "":
		return nil, nil, errors.New("adapter not found")
	}

	return cluster, adapter, nil
}
// createFuncMap returns the custom functions made available to job-script
// templates: regexMatch, required, error and default.
func createFuncMap() template.FuncMap {
	funcs := make(template.FuncMap, 4)
	funcs["regexMatch"] = regexMatch
	funcs["required"] = required
	funcs["error"] = errorHandler
	funcs["default"] = defaultHandler
	return funcs
}
// templateCallErrorRe matches the "error calling <func>: <message>" tail that
// the template engine appends when a template function returns an error.
// It is compiled once at package scope instead of on every call.
var templateCallErrorRe = regexp.MustCompile(`error calling \w+: (.*)$`)

// extractUserError strips the template engine's wrapping from an execution
// error and returns a new error holding only the message produced by the
// failing template function.
//
// When originalErr does not have the "error calling <func>: <message>" shape
// it is returned unchanged. A nil originalErr yields nil.
func extractUserError(originalErr error) error {
	if originalErr == nil {
		return nil
	}
	if matches := templateCallErrorRe.FindStringSubmatch(originalErr.Error()); len(matches) > 1 {
		return errors.New(matches[1])
	}
	return originalErr
}
// regexMatch is the template helper that compiles pattern into a regular
// expression. It uses regexp.MustCompile, which panics on an invalid pattern.
func regexMatch(pattern string) *regexp.Regexp {
	compiled := regexp.MustCompile(pattern)
	return compiled
}
// required is the template helper that enforces a mandatory value. It returns
// val unchanged, or an error carrying msg when val is nil or the empty string.
func required(msg string, val interface{}) (interface{}, error) {
	switch val {
	case nil, "":
		return nil, errors.New(msg)
	default:
		return val, nil
	}
}
// errorHandler is the template helper that always fails: it returns an empty
// string together with an error carrying msg.
func errorHandler(msg string) (string, error) {
	failure := errors.New(msg)
	return "", failure
}
// defaultHandler is the template helper that substitutes a fallback value. It
// returns defaultVal when val is nil, an empty string or the int 0, and val
// itself in every other case (other types are never treated as empty).
func defaultHandler(defaultVal interface{}, val interface{}) interface{} {
	if val == nil {
		return defaultVal
	}
	if text, isString := val.(string); isString && text == "" {
		return defaultVal
	}
	if number, isInt := val.(int); isInt && number == 0 {
		return defaultVal
	}
	return val
}
// RenderJobScript executes templateContent as a text/template and returns the
// rendered script.
//
// Parsed templates are looked up in templateCache by their raw text and stored
// there after the first successful parse. The template sees req.Common as
// ".Common" and req.AppSpecific as ".App". A parse error is returned as is; an
// execution error is logged and then reduced by extractUserError.
func (l *CommitHpcTaskLogic) RenderJobScript(templateContent string, req *JobRequest) (string, error) {
	var compiled *template.Template
	if cached, hit := templateCache.Load(templateContent); hit {
		compiled = cached.(*template.Template)
	} else {
		fresh, parseErr := template.New("slurmTemplate").Funcs(createFuncMap()).Parse(templateContent)
		if parseErr != nil {
			return "", parseErr
		}
		templateCache.Store(templateContent, fresh)
		compiled = fresh
	}

	data := make(map[string]interface{}, 2)
	data["Common"] = req.Common
	data["App"] = req.AppSpecific

	var out strings.Builder
	execErr := compiled.Execute(&out, data)
	if execErr != nil {
		log.Error().Err(execErr).Msg("模板渲染失败")
		return "", extractUserError(execErr)
	}
	return out.String(), nil
}
// ConvertToJobRequest maps a commit request onto the JobRequest consumed by
// RenderJobScript.
//
// The "jobName", "nodes" and "ntasks" parameters must be non-empty; the first
// one found empty produces an "<key> is empty" error. The full parameter map
// is also passed through as AppSpecific.
func ConvertToJobRequest(job *types.CommitHpcTaskReq) (JobRequest, error) {
	params := job.Parameters

	for _, key := range [...]string{"jobName", "nodes", "ntasks"} {
		if len(params[key]) == 0 {
			return JobRequest{}, fmt.Errorf("%s is empty", key)
		}
	}

	common := CommonParams{
		JobName:   params["jobName"],
		Partition: params["partition"],
		Nodes:     params["nodes"],
		NTasks:    params["ntasks"],
		Time:      params["time"],
		App:       job.App,
	}

	converted := JobRequest{
		App:         job.App,
		Common:      common,
		AppSpecific: utils.MpaStringToInterface(params),
	}
	return converted, nil
}
// SaveHpcTaskToDB persists a submitted HPC job in a single transaction: one
// row in "task", one in "task_hpc" and one creation notice in "t_notice".
//
// jobScript, jobId and workDir are stored on the task_hpc row; the whole
// request is additionally stored there as JSON. On success the generated task
// ID is returned as a string. Any error, or a panic raised while the
// transaction is open, rolls the transaction back and is reported through err.
func (l *CommitHpcTaskLogic) SaveHpcTaskToDB(req *types.CommitHpcTaskReq, jobScript, jobId, workDir string) (taskId string, err error) {
	// All three inserts must succeed or fail together.
	tx := l.svcCtx.DbEngin.Begin()
	if beginErr := tx.Error; beginErr != nil {
		// Without this check a failed Begin would only surface later as a
		// confusing error from the first insert.
		return "", fmt.Errorf("failed to begin transaction: %w", beginErr)
	}
	defer func() {
		// The named result err lets this handler both observe failures from
		// the body and report a recovered panic to the caller.
		if r := recover(); r != nil {
			tx.Rollback()
			err = fmt.Errorf("transaction panic: %v", r)
		} else if err != nil {
			tx.Rollback()
		}
	}()

	// The parse error is deliberately ignored: a missing or non-numeric
	// "UserId" parameter leaves userID at 0 rather than failing the save.
	userID, _ := strconv.ParseInt(req.Parameters["UserId"], 10, 64)

	taskID := utils.GenSnowflakeID()
	taskModel := models.Task{
		Id:              taskID,
		Name:            req.Name,
		Description:     req.Description,
		CommitTime:      time.Now(),
		Status:          statusSaved,
		AdapterTypeDict: adapterTypeHPC,
		UserId:          userID,
	}
	if err = tx.Table("task").Create(&taskModel).Error; err != nil {
		return "", fmt.Errorf("failed to create task: %w", err)
	}

	// err here is the named result (same scope), so the deferred rollback
	// still sees failures from the lookups below.
	clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
	if err != nil {
		return "", err
	}

	paramsJSON, err := jsoniter.MarshalToString(req)
	if err != nil {
		return "", fmt.Errorf("failed to marshal parameters: %w", err)
	}

	clusterID := utils.StringToInt64(clusterInfo.Id)
	hpcTask := models.TaskHpc{
		Id:          utils.GenSnowflakeID(),
		TaskId:      taskID,
		AdapterId:   clusterInfo.AdapterId,
		AdapterName: adapterInfo.Name,
		ClusterId:   clusterID,
		ClusterName: clusterInfo.Name,
		Name:        taskModel.Name,
		Backend:     req.Backend,
		OperateType: req.OperateType,
		CmdScript:   req.Parameters["cmdScript"],
		WallTime:    req.Parameters["wallTime"],
		AppType:     req.Parameters["appType"],
		AppName:     req.App,
		Queue:       req.Parameters["queue"],
		SubmitType:  req.Parameters["submitType"],
		NNode:       req.Parameters["nNode"],
		Account:     clusterInfo.Username,
		StdInput:    req.Parameters["stdInput"],
		Partition:   req.Parameters["partition"],
		CreatedTime: time.Now(),
		UpdatedTime: time.Now(),
		Status:      statusDeploying,
		UserId:      userID,
		Params:      paramsJSON,
		Script:      jobScript,
		JobId:       jobId,
		WorkDir:     workDir,
	}
	if err = tx.Table("task_hpc").Create(&hpcTask).Error; err != nil {
		return "", fmt.Errorf("failed to create HPC task: %w", err)
	}

	noticeInfo := clientCore.NoticeInfo{
		AdapterId:   clusterInfo.AdapterId,
		AdapterName: adapterInfo.Name,
		ClusterId:   clusterID,
		ClusterName: clusterInfo.Name,
		NoticeType:  "create",
		TaskName:    req.Name,
		TaskId:      taskID,
		Incident:    "任务创建中",
		CreatedTime: time.Now(),
	}
	if err = tx.Table("t_notice").Create(&noticeInfo).Error; err != nil {
		return "", fmt.Errorf("failed to create notice: %w", err)
	}

	if err = tx.Commit().Error; err != nil {
		return "", fmt.Errorf("transaction commit failed: %w", err)
	}
	return utils.Int64ToString(taskID), nil
}
// CommitHpcTask submits an HPC job to the cluster named in req and records it
// in the database.
//
// Steps:
//  1. Derive the job name (see generateJobName) and store it in
//     req.Parameters["jobName"].
//  2. Resolve the cluster and its adapter.
//  3. Use req.ScriptContent as the job script, or, when it is empty, load the
//     matching row from hpc_app_template and render it with the request
//     parameters.
//  4. Submit the job through the executor registered for the adapter.
//  5. Persist the task via SaveHpcTaskToDB and add the new task ID to the
//     response's JobInfo under "taskId".
//
// A submission failure is logged in full but reported to the caller with a
// generic message.
func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
	jobName := generateJobName(req)
	// A request without parameters leaves the map nil; writing into a nil map
	// would panic.
	if req.Parameters == nil {
		req.Parameters = make(map[string]string)
	}
	req.Parameters["jobName"] = jobName

	// Resolve the cluster and adapter.
	clusterInfo, adapterInfo, err := l.getClusterInfo(req.ClusterId)
	if err != nil {
		return nil, err
	}

	scriptContent := req.ScriptContent
	if scriptContent == "" {
		// No script supplied: look up the template for this cluster and app.
		var templateInfo types.HpcAppTemplateInfo
		query := l.svcCtx.DbEngin.Table("hpc_app_template").
			Where("cluster_id = ? and app = ? ", req.ClusterId, req.App)
		if req.OperateType != "" {
			// Keep the returned handle so the filter does not depend on the
			// ORM mutating the receiver in place.
			query = query.Where("app_type = ?", req.OperateType)
		}
		if err := query.First(&templateInfo).Error; err != nil {
			return nil, fmt.Errorf("failed to get template: %w", err)
		}

		// Convert the request parameters into template input.
		jobRequest, err := ConvertToJobRequest(req)
		if err != nil {
			return nil, err
		}

		// Render the script.
		script, err := l.RenderJobScript(templateInfo.Content, &jobRequest)
		if err != nil {
			return nil, err
		}
		scriptContent = script
	}

	// Best-effort JSON encoding of the script for the log line only; the
	// error is ignored because it does not affect the submission.
	q, _ := jsoniter.MarshalToString(scriptContent)
	submitQ := types.SubmitHpcTaskReq{
		App:           req.App,
		ClusterId:     req.ClusterId,
		JobName:       jobName,
		ScriptContent: scriptContent,
		Parameters:    req.Parameters,
		Backend:       req.Backend,
	}
	log.Info().Msgf("Submitting HPC task to cluster %s with params: %s", clusterInfo.Name, q)

	// Fail with an error instead of calling through a missing map entry.
	executor, ok := l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id]
	if !ok {
		return nil, fmt.Errorf("no HPC executor registered for adapter %s", adapterInfo.Id)
	}
	resp, err = executor.SubmitTask(l.ctx, submitQ)
	if err != nil {
		log.Error().Err(err).Msgf("提交HPC任务失败, cluster: %s, jobName: %s, scriptContent: %s", clusterInfo.Name, jobName, scriptContent)
		return nil, fmt.Errorf("网络请求失败,请稍后重试")
	}

	jobID := resp.Data.JobInfo["jobId"]
	workDir := resp.Data.JobInfo["jobDir"]
	taskID, err := l.SaveHpcTaskToDB(req, scriptContent, jobID, workDir)
	if err != nil {
		log.Error().Msgf("Failed to save task to DB: %v", err)
		return nil, fmt.Errorf("db save failed: %w", err)
	}
	resp.Data.JobInfo["taskId"] = taskID
	return resp, nil
}
// generateJobName returns the request name, suffixed with "_<OperateType>"
// when the request carries a non-empty operate type.
func generateJobName(req *types.CommitHpcTaskReq) string {
	name := req.Name
	if suffix := req.OperateType; suffix != "" {
		name += "_" + suffix
	}
	return name
}