pcm-coordinator/internal/logic/hpc/commithpctasklogic.go

159 lines
5.0 KiB
Go

package hpc
import (
"context"
"errors"
jsoniter "github.com/json-iterator/go"
clientCore "gitlink.org.cn/JointCloud/pcm-coordinator/client"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
"strconv"
"time"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
// CommitHpcTaskLogic handles a request to commit an HPC task to a cluster.
// Instances are created with NewCommitHpcTaskLogic.
type CommitHpcTaskLogic struct {
	// Logger bound to ctx via logx.WithContext.
	logx.Logger
	// Context supplied by the caller of NewCommitHpcTaskLogic.
	ctx context.Context
	// Shared service dependencies; this file reads its DbEngin, Config and
	// Scheduler.HpcStorages.
	svcCtx *svc.ServiceContext
	// Source of the per-adapter HPC executors (HpcExecutorAdapterMap) used
	// to submit jobs.
	hpcService *service.HpcService
}
// NewCommitHpcTaskLogic builds a CommitHpcTaskLogic for ctx.
//
// Every call creates a new HpcService from svcCtx.Config and
// svcCtx.Scheduler.HpcStorages, backed by a fresh cache map.
//
// It returns nil when the HpcService cannot be created. The signature has no
// error result, so the cause is logged, and callers must check for nil before
// using the returned value.
func NewCommitHpcTaskLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CommitHpcTaskLogic {
	logger := logx.WithContext(ctx)

	cache := make(map[string]interface{}, 10)
	hpcService, err := service.NewHpcService(&svcCtx.Config, svcCtx.Scheduler.HpcStorages, cache)
	if err != nil {
		// The error cannot be returned to the caller; surface it in the logs
		// rather than failing silently.
		logger.Errorf("create hpc service: %v", err)
		return nil
	}

	return &CommitHpcTaskLogic{
		Logger:     logger,
		ctx:        ctx,
		svcCtx:     svcCtx,
		hpcService: hpcService,
	}
}
// CommitHpcTask persists an HPC task and submits it to the cluster identified
// by req.ClusterId.
//
// Steps, in order:
//  1. Load the cluster (t_cluster) and its adapter (t_adapter), resolve the
//     HPC executor registered for that adapter, and parse the cluster id.
//     All validation happens before any row is written, so a rejected
//     request leaves nothing behind in the database.
//  2. Insert the parent models.Task row (status "Saved") and the
//     models.TaskHpc detail row (status "Deploying").
//  3. Record a "create" notice in t_notice. This step is best-effort: a
//     failure is logged and the submission continues.
//  4. Submit the job through the adapter's executor and store the returned
//     "jobId" / "jobDir" on the TaskHpc row.
//
// On success, resp.Data.JobInfo also carries "taskId", the id of the parent
// Task row. Validation failures return descriptive errors. Database and
// executor errors are returned unchanged.
func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *types.CommitHpcTaskResp, err error) {
	// The full request is stored on both task rows as YAML.
	reqStr, err := jsoniter.MarshalToString(req)
	if err != nil {
		return nil, err
	}
	yaml := utils.StringToYaml(reqStr)

	// --- Validation: nothing is written until all of this succeeds. ---
	var clusterInfo types.ClusterInfo
	l.svcCtx.DbEngin.Raw("SELECT * FROM `t_cluster` where id = ?", req.ClusterId).First(&clusterInfo)
	if len(clusterInfo.Id) == 0 {
		return nil, errors.New("cluster not found")
	}

	var adapterInfo types.AdapterInfo
	l.svcCtx.DbEngin.Raw("SELECT * FROM `t_adapter` where id = ?", clusterInfo.AdapterId).Scan(&adapterInfo)
	if adapterInfo.Id == "" {
		return nil, errors.New("adapter not found")
	}

	// Calling SubmitTask on the entry for an unregistered adapter would
	// panic, so reject the request before any row is created.
	executor, ok := l.hpcService.HpcExecutorAdapterMap[adapterInfo.Id]
	if !ok {
		return nil, errors.New("no hpc executor registered for adapter " + adapterInfo.Id)
	}

	clusterId, err := strconv.ParseInt(req.ClusterId, 10, 64)
	if err != nil {
		return nil, errors.New("invalid cluster id: " + err.Error())
	}

	// Optional numeric parameters: missing or malformed values become 0.
	userId, _ := strconv.ParseInt(req.Parameters["UserId"], 10, 64)
	cardCount, _ := strconv.ParseInt(req.Parameters["cardCount"], 10, 64)
	timelimit, _ := strconv.ParseInt(req.Parameters["timeLimit"], 10, 64)

	now := time.Now()

	// --- Persist the parent task. ---
	taskModel := models.Task{
		Name:        req.Name,
		Description: req.Description,
		CommitTime:  now,
		Status:      "Saved",
		// NOTE(review): "2" presumably identifies the HPC adapter type — confirm.
		AdapterTypeDict: "2",
		UserId:          userId,
		YamlString:      *yaml,
	}
	if tx := l.svcCtx.DbEngin.Create(&taskModel); tx.Error != nil {
		return nil, tx.Error
	}

	// --- Persist the HPC-specific task details. ---
	hpcInfo := models.TaskHpc{
		TaskId:      taskModel.Id,
		AdapterId:   clusterInfo.AdapterId,
		AdapterName: adapterInfo.Name,
		ClusterId:   clusterId,
		ClusterName: clusterInfo.Name,
		Name:        taskModel.Name,
		Backend:     req.Backend,
		OperateType: req.OperateType,
		CmdScript:   req.Parameters["cmdScript"],
		CardCount:   cardCount,
		// The work dir is the cluster's base dir plus the "WorkDir" parameter.
		// NOTE(review): every other key here is lowerCamel. The earlier
		// "workDir" value was always overwritten by this one. Confirm which
		// key clients actually send.
		WorkDir:     clusterInfo.WorkDir + req.Parameters["WorkDir"],
		WallTime:    req.Parameters["wallTime"],
		AppType:     req.Parameters["appType"],
		AppName:     req.Parameters["appName"],
		Queue:       req.Parameters["queue"],
		SubmitType:  req.Parameters["submitType"],
		NNode:       req.Parameters["nNode"],
		Account:     clusterInfo.Username,
		StdInput:    req.Parameters["stdInput"],
		Partition:   req.Parameters["partition"],
		CreatedTime: now,
		UpdatedTime: now,
		Status:      "Deploying",
		TimeLimit:   timelimit,
		UserId:      userId,
		YamlString:  *yaml,
	}
	if tx := l.svcCtx.DbEngin.Create(&hpcInfo); tx.Error != nil {
		return nil, tx.Error
	}

	// --- Record the operation notice (best-effort). ---
	noticeInfo := clientCore.NoticeInfo{
		AdapterId:   clusterInfo.AdapterId,
		AdapterName: adapterInfo.Name,
		ClusterId:   clusterId,
		ClusterName: clusterInfo.Name,
		NoticeType:  "create",
		TaskName:    req.Name,
		Incident:    "任务创建中",
		CreatedTime: now,
	}
	if result := l.svcCtx.DbEngin.Table("t_notice").Create(&noticeInfo); result.Error != nil {
		l.Errorf("Task notice creation failure, err: %v", result.Error)
	}

	// On-chain evidence recording (disabled): would look up the resource
	// price and send the task record to the blockchain service.
	//var price int64
	//l.svcCtx.DbEngin.Raw("select price from `resource_cost` where resource_id = ?", clusterId).Scan(&price)
	//bytes, _ := json.Marshal(taskModel)
	//remoteUtil.Evidence(remoteUtil.EvidenceParam{
	//	UserIp:          req.Parameters["UserIp"],
	//	Url:             l.svcCtx.Config.BlockChain.Url,
	//	ContractAddress: l.svcCtx.Config.BlockChain.ContractAddress,
	//	FunctionName:    l.svcCtx.Config.BlockChain.FunctionName,
	//	Type:            l.svcCtx.Config.BlockChain.Type,
	//	Token:           req.Parameters["Token"],
	//	Amount:          price,
	//	Args:            []string{strconv.FormatInt(taskModel.Id, 10), string(bytes)},
	//})

	// --- Submit the job to the target cluster. ---
	l.Info("提交job到指定集群")
	// NOTE(review): context.Background() detaches the submission from the
	// request context. It is not cancelled with the request, and the request's
	// deadline/trace data is not propagated. Confirm this is intentional
	// before switching to l.ctx.
	resp, err = executor.SubmitTask(context.Background(), *req)
	if err != nil {
		// NOTE(review): the TaskHpc row is left in "Deploying" status here.
		return nil, err
	}
	if resp == nil {
		return nil, errors.New("hpc executor returned an empty response")
	}

	// --- Record the cluster-side job id and directory. ---
	// With GORM, Updates on a struct skips zero-value fields, so an empty
	// jobDir does not clear the stored WorkDir.
	updates := l.svcCtx.DbEngin.Model(&hpcInfo).Updates(models.TaskHpc{
		Id:      hpcInfo.Id,
		JobId:   resp.Data.JobInfo["jobId"],
		WorkDir: resp.Data.JobInfo["jobDir"],
	})
	if updates.Error != nil {
		return nil, updates.Error
	}

	// NOTE(review): this write assumes the executor returns a non-nil JobInfo map.
	resp.Data.JobInfo["taskId"] = strconv.FormatInt(taskModel.Id, 10)
	return resp, nil
}