# SchedulingSimulator/commit_task_jointcloud_isda...

import concurrent.futures
import time
import logging
import threading
from uuid import uuid4
from typing import Dict, List, Optional
import requests
import os
import json
# -------------------------- Global configuration and constants --------------------------
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Task status definitions
TASK_STATUS = {
    "SUBMITTED": "pending",               # initial state, waiting to be submitted
    "SUBMITTING": "submitting",           # submission in progress
    "SUCCEED": "succeeded",               # confirmed by the JointCloud platform
    "FAILED": "failed",                   # rejected or failed on the JointCloud platform
    "RETRY_EXHAUSTED": "retry exhausted"  # exceeded the maximum failure count
}
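# Task lifecycle (as implemented by the submit and monitor threads below):
#   SUBMITTED -> SUBMITTING -> SUCCEED
#                     \-> FAILED (re-queued until fail_count reaches
#                         max_fail_threshold) -> RETRY_EXHAUSTED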
# Global task dictionary (key=target_id, value=task details)
task_map: Dict[str, Dict] = {}
task_map_lock = threading.Lock()  # lock guarding task_map
# Global dataset map (key=file_location, value=DatasetInfo instance) - new
dataset_map: Dict[str, Dict] = {}
dataset_lock = threading.Lock()  # lock guarding dataset_map
# API endpoint configuration
API_CONFIG = {
    "login": {
        "url": "http://119.45.255.234:30180/jcc-admin/admin/login",
        "timeout": 10
    },
    "create_package": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/createPackage",
        "timeout": 15
    },
    "upload_file": {
        "url": "http://119.45.255.234:30180/jcs/object/upload",
        "timeout": 3000
    },
    "notify_upload": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/notifyUploaded",
        "timeout": 15
    },
    "bind_cluster": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/binding",
        "timeout": 15
    },
    "query_binding": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/queryBinding",
        "timeout": 15
    },
    "submit_task": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/submit",
        "timeout": 100
    },
    "task_details": {
        "url": "http://119.45.255.234:30180/pcm/v1/core/task/details",
        "timeout": 15
    }
}
# Cluster resource configuration (key=cluster ID, value=total/available resources)
cluster_resources: Dict[str, Dict] = {
    "1790300942428540928": {  # modelarts cluster
        "total": {"CPU": 512, "MEMORY": 1024, "NPU": 8},
        "available": {"CPU": 96, "MEMORY": 1024, "NPU": 8}
    }
}
cluster_lock = threading.Lock()  # lock guarding cluster_resources
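# Resource accounting: select_cluster() deducts a task's requested resources from
# "available" under cluster_lock as soon as a cluster is chosen; the failure paths
# in submit_single_task() add them back. A successfully submitted task keeps its
# reservation.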
# -------------------------- Data structures --------------------------
class DatasetInfo(dict):
    """Dataset information record"""
    def __init__(self, file_location: str, name: str, size: float, **kwargs):
        super().__init__()
        self["file_location"] = file_location                      # local path (primary key)
        self["id"] = kwargs.get("id", str(uuid4()))                # unique dataset ID
        self["name"] = name                                        # dataset name
        self["size"] = size                                        # size in bytes
        self["is_uploaded"] = kwargs.get("is_uploaded", False)     # whether the dataset has been uploaded
        self["upload_cluster"] = kwargs.get("upload_cluster", [])  # clusters it was uploaded to
        self["upload_time"] = kwargs.get("upload_time")            # upload timestamp
        self["description"] = kwargs.get("description")            # description
        self["dataset_target_id"] = kwargs.get("dataset_target_id")  # binding ID (new)
class AlgorithmInfo(dict):
    """Algorithm information record"""
    def __init__(self, cluster: str, id: str, name: str, **kwargs):
        super().__init__()
        self["cluster"] = cluster                  # owning cluster
        self["id"] = id                            # unique algorithm ID
        self["son_id"] = kwargs.get("son_id", "")  # sub-algorithm ID
        self["name"] = name                        # algorithm name
class TaskInfo(dict):
    """Task information record"""
    def __init__(self, task_name: str, dataset_name: str, son_code_id: int, resource: Dict, **kwargs):
        super().__init__()
        self["target_id"] = kwargs.get("target_id", str(uuid4()))  # unique task ID
        self["task_name"] = task_name                              # task name
        self["package_name"] = kwargs.get("package_name", f"{task_name.lower()}-pkg")  # package folder name
        self["dataset_name"] = dataset_name                        # associated dataset name
        self["son_code_id"] = son_code_id                          # sub-algorithm ID (taken directly from the template)
        self["resource"] = resource                                # resource demand (CPU/MEMORY/NPU, etc.)
        self["status"] = TASK_STATUS["SUBMITTED"]                  # initial state: pending
        self["submit_time"] = kwargs.get("submit_time", time.strftime("%Y-%m-%d %H:%M:%S"))  # submit time
        self["success_time"] = None                                # success time (filled in on success)
        self["third_party_task_id"] = ""                           # JointCloud task ID (filled in after submission)
        self["file_location"] = kwargs.get("file_location", "")    # local file path
        self["error_msg"] = ""                                     # error message
        self["fail_count"] = 0                                     # failure count
        self["max_fail_threshold"] = kwargs.get("max_fail_threshold", 3)  # maximum failure threshold
        self["cluster_id"] = ""                                    # target cluster ID (filled in at submission)
# -------------------------- Utilities --------------------------
def generate_task_templates() -> List[Dict]:
    """Generate static task templates (including son_code_id)"""
    return [
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AA",
            "dataset_name": "data1.zip",
            "son_code_id": 1217,  # sub-algorithm ID, defined directly in the template
            "resource": {"CPU": 12, "MEMORY": 24, "NPU": 1},
            "file_location": "D:/数据集/cnn数据集/data1/"
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AB",
            "dataset_name": "cifar-10-python.tar.gz",
            "son_code_id": 1167,
            "resource": {"CPU": 12, "MEMORY": 24, "NPU": 1},
            "file_location": "D:/数据集/cnn数据集/data2/"
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AC",
            "dataset_name": "cifar-100-python.tar.gz",
            "son_code_id": 1169,
            "resource": {"CPU": 12, "MEMORY": 24, "NPU": 1},
            "file_location": "D:/数据集/cnn数据集/data3/"
        }
    ]
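# Example: the first template above expands to the task name
# "AA-jointCloudAi-trainingtask" (the prefix "AA" is substituted into
# "{prefix}-jointCloudAi-trainingtask" by load_tasks_to_queue below).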
def load_tasks_to_queue(templates: List[Dict]) -> None:
    """Load static task templates into the task queue (using son_code_id)"""
    global task_map
    with task_map_lock:
        task_map.clear()
        for idx, template in enumerate(templates):
            try:
                # Check required fields (updated for son_code_id)
                required_fields = ["task_name_template", "prefix", "dataset_name", "son_code_id", "resource",
                                   "file_location"]
                missing_fields = [f for f in required_fields if f not in template]
                if missing_fields:
                    logger.warning(f"Skipping invalid task template (index {idx}): missing fields {missing_fields}")
                    continue
                task_name = template["task_name_template"].format(prefix=template["prefix"])
                task = TaskInfo(
                    task_name=task_name,
                    dataset_name=template["dataset_name"],
                    son_code_id=template["son_code_id"],  # use the sub-algorithm ID from the template directly
                    resource=template["resource"],
                    file_location=template["file_location"]
                )
                task_map[task["target_id"]] = task
                logger.info(f"Task enqueued | task_name: {task_name} | target_id: {task['target_id']}")
            except Exception as e:
                logger.error(f"Failed to load task template (index {idx}): {str(e)}")
    logger.info(f"Task queue loaded | {len(task_map)} valid task(s)")
def select_cluster(task_resource: Dict) -> Optional[str]:
    """Select a suitable cluster for the task's resource demand"""
    with cluster_lock:
        for cluster_id, cluster in cluster_resources.items():
            # Check whether the cluster's available resources cover the demand
            # (supports different accelerator types such as NPU/DCU)
            resource_match = True
            for res_type, required in task_resource.items():
                available = cluster["available"].get(res_type, 0)
                if available < required:
                    resource_match = False
                    break
            if resource_match:
                # Reserve the resources on the chosen cluster (deduct the demand)
                for res_type, required in task_resource.items():
                    if res_type in cluster["available"]:
                        cluster["available"][res_type] -= required
                logger.info(f"Selected cluster {cluster_id}, available resources now: {cluster['available']}")
                return cluster_id
        logger.warning(f"No cluster satisfies the resource demand | demand: {task_resource}")
        return None
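# NOTE: selection is first-fit: clusters are scanned in dict insertion order and
# the first one whose available resources cover every requested type wins.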
# -------------------------- New: dataset upload check and handling --------------------------
def check_and_handle_dataset(file_location: str, dataset_name: str, cluster_id: str) -> Optional[str]:
    """
    Check whether the dataset has been uploaded to the given cluster; upload it if not.
    Args:
        file_location: local file path (primary key)
        dataset_name: dataset name
        cluster_id: target cluster ID
    Returns:
        dataset_target_id: the dataset binding ID (the existing ID if already
        uploaded, a new ID after a fresh upload), or None on failure
    """
    global dataset_map
    # Step 1: check whether the dataset has already been uploaded
    dataset = get_dataset_status(file_location, dataset_name, cluster_id)
    # Step 2: upload it if not
    if not dataset or not dataset["is_uploaded"] or cluster_id not in dataset["upload_cluster"]:
        dataset = upload_dataset(file_location, dataset_name, cluster_id)
        if not dataset:
            return None  # upload failed
    return dataset["dataset_target_id"]
def get_dataset_status(file_location: str, dataset_name: str, cluster_id: str) -> Optional[Dict]:
    """
    Check the dataset's status (whether it has been uploaded to the given cluster).
    Returns:
        The dataset info dict if it exists and has been uploaded to the cluster,
        otherwise None.
    """
    with dataset_lock:
        if file_location in dataset_map:
            dataset = dataset_map[file_location]
            # Verify it has been uploaded to the target cluster
            if dataset["is_uploaded"] and cluster_id in dataset["upload_cluster"]:
                logger.info(
                    f"Dataset {dataset_name} already uploaded to cluster {cluster_id} | target_id: {dataset['dataset_target_id']}")
                return dataset
    return None
def upload_dataset(file_location: str, dataset_name: str, cluster_id: str) -> Optional[Dict]:
    """
    Run the dataset upload flow.
    Returns:
        The dataset info dict on success, None on failure.
    """
    dataset_path = os.path.join(file_location, dataset_name)
    # Check that the local file exists
    if not os.path.exists(dataset_path):
        logger.error(f"Local dataset file not found | path: {dataset_path}")
        return None
    # Determine the file size in bytes
    try:
        file_size = os.path.getsize(dataset_path)
    except OSError as e:
        logger.error(f"Failed to read file size | path: {dataset_path} | error: {str(e)}")
        return None
    logger.info(f"Uploading dataset {dataset_name} to cluster {cluster_id} | path: {dataset_path}")
    try:
        # Obtain an auth token
        token = get_token()
        if not token:
            logger.error("Failed to obtain a token; cannot upload the dataset")
            return None
        headers = {'Authorization': f'Bearer {token}'}
        package_name = f"dataset-{dataset_name.split('.')[0]}-{uuid4().hex[:6]}"  # unique folder name
        # 1. Create the dataset package (folder)
        create_payload = {
            "userID": 5,
            "name": package_name,
            "dataType": "dataset",
            "packageID": 0,
            "uploadPriority": {"type": "specify", "clusters": [cluster_id]},
            "bindingInfo": {
                "clusterIDs": [cluster_id],
                "name": package_name,
                "category": "image",
                "type": "dataset",
                "imageID": "",
                "chip": ["ASCEND"],
                "packageID": 0
            }
        }
        create_resp = requests.post(
            API_CONFIG["create_package"]["url"],
            json=create_payload,
            headers=headers,
            timeout=API_CONFIG["create_package"]["timeout"]
        )
        create_resp.raise_for_status()
        create_result = create_resp.json()
        if create_result.get("code") != "OK":
            raise ValueError(f"Failed to create the package | response: {create_result}")
        packageID = create_result["data"]["newPackage"]["packageID"]
        logger.info(f"Dataset package created | packageID: {packageID}")
        # 2. Upload the file
        info_data = {
            "userID": 5,
            "packageID": packageID,
            "loadTo": [3],
            "loadToPath": [f"/dataset/5/{package_name}/"]
        }
        with open(dataset_path, 'rb') as f:
            form_data = {"info": (None, json.dumps(info_data)), "files": f}
            upload_resp = requests.post(
                API_CONFIG["upload_file"]["url"],
                files=form_data,
                headers=headers,
                timeout=API_CONFIG["upload_file"]["timeout"]
            )
        upload_resp.raise_for_status()
        upload_result = upload_resp.json()
        if upload_result.get("code") != "OK":
            raise ValueError(f"File upload failed | response: {upload_result}")
        object_id = upload_result["data"]["uploadeds"][0]["objectID"]
        logger.info(f"Dataset file uploaded | objectID: {object_id}")
        # 3. Notify that the upload is complete
        notify_payload = {
            "userID": 5,
            "packageID": packageID,
            "uploadParams": {
                "dataType": "dataset",
                "uploadInfo": {"type": "local", "localPath": dataset_name, "objectIDs": [object_id]}
            }
        }
        notify_resp = requests.post(
            API_CONFIG["notify_upload"]["url"],
            json=notify_payload,
            headers=headers,
            timeout=API_CONFIG["notify_upload"]["timeout"]
        )
        notify_resp.raise_for_status()
        # 4. Bind the dataset to the cluster
        bind_payload = {
            "userID": 5,
            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": [cluster_id]}
        }
        bind_resp = requests.post(
            API_CONFIG["bind_cluster"]["url"],
            json=bind_payload,
            headers=headers,
            timeout=API_CONFIG["bind_cluster"]["timeout"]
        )
        bind_resp.raise_for_status()
        # 5. Query the binding ID (dataset_target_id)
        query_resp = requests.post(
            API_CONFIG["query_binding"]["url"],
            json={"dataType": "dataset", "param": {"userID": 5, "bindingID": -1, "type": "private"}},
            headers=headers,
            timeout=API_CONFIG["query_binding"]["timeout"]
        ).json()
        if query_resp.get("code") != "OK":
            raise ValueError(f"Failed to query the binding ID | response: {query_resp}")
        dataset_target_id = None
        for item in query_resp["data"]["datas"]:
            if item["info"]["name"] == package_name:
                dataset_target_id = item["ID"]
                break
        if not dataset_target_id:
            raise ValueError(f"No binding ID found for dataset {package_name}")
        # Upload succeeded: create and cache the dataset record
        dataset = DatasetInfo(
            file_location=file_location,
            name=dataset_name,
            size=file_size,
            is_uploaded=True,
            upload_cluster=[cluster_id],
            upload_time=time.strftime("%Y-%m-%d %H:%M:%S"),
            dataset_target_id=dataset_target_id
        )
        with dataset_lock:
            dataset_map[file_location] = dataset
        logger.info(f"Dataset {dataset_name} uploaded successfully | target_id: {dataset_target_id}")
        return dataset
    except Exception as e:
        logger.error(f"Dataset upload failed | name: {dataset_name} | error: {str(e)}", exc_info=True)
        return None
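# NOTE: dataset_map keys on file_location, so each local dataset is uploaded at
# most once per target cluster within a process; later tasks that reference the
# same path reuse the cached dataset_target_id via check_and_handle_dataset().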
# -------------------------- API call helpers --------------------------
def get_token() -> Optional[str]:
    """Obtain an authentication token"""
    login_payload = {"username": "admin", "password": "Nudt@123"}
    try:
        config = API_CONFIG["login"]
        response = requests.post(config["url"], json=login_payload, timeout=config["timeout"])
        response.raise_for_status()
        result = response.json()
        if result.get("code") == 200 and "data" in result and "token" in result["data"]:
            logger.info("Token obtained")
            return result["data"]["token"]
        else:
            logger.error(f"Failed to obtain a token | response: {result}")
            return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Login request failed: {str(e)}", exc_info=True)
        return None
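# NOTE: a fresh token is requested for every operation (dataset upload, task
# submission, status query); tokens are not cached or refreshed here.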
def submit_single_task(task: Dict) -> bool:
    """Submit a single task to a cluster (using the template's son_code_id)"""
    token = get_token()
    if not token:
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "Failed to obtain a token"
        return False
    # Fetch the selected cluster ID
    cluster_id = task.get("cluster_id")
    if not cluster_id:
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "No cluster ID assigned"
        logger.error(f"[{task['task_name']}] Submission failed: no cluster ID assigned")
        return False
    # -------------------------- New: dataset check and upload --------------------------
    # Check the dataset's status before creating the package
    dataset_target_id = check_and_handle_dataset(
        file_location=task["file_location"],
        dataset_name=task["dataset_name"],
        cluster_id=cluster_id
    )
    if not dataset_target_id:
        # Dataset upload failed: release the reserved cluster resources
        with cluster_lock:
            if cluster_id in cluster_resources:
                for res_type, required in task["resource"].items():
                    if res_type in cluster_resources[cluster_id]["available"]:
                        cluster_resources[cluster_id]["available"][res_type] += required
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "Dataset upload failed"
        logger.error(f"[{task['task_name']}] Submission failed: dataset handling failed")
        return False
    # --------------------------------------------------------------------------
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {token}'
    }
    task_name = task["task_name"]
    package_name = task["package_name"]
    son_code_id = task["son_code_id"]  # sub-algorithm ID taken from the template
    try:
        # Step 1: create the dataset package on the selected cluster
        # (reuses the original flow, but without re-uploading the data)
        config = API_CONFIG["create_package"]
        create_payload = {
            "userID": 5,
            "name": package_name,
            "dataType": "dataset",
            "packageID": 0,
            "uploadPriority": {"type": "specify", "clusters": [cluster_id]},
            "bindingInfo": {
                "clusterIDs": [cluster_id],
                "name": package_name,
                "category": "image",
                "type": "dataset",
                "imageID": "",
                "bias": [],
                "region": [],
                "chip": ["ASCEND"],
                "selectedCluster": [],
                "modelType": "",
                "env": "",
                "version": "",
                "packageID": 0,
                "points": 0
            }
        }
        create_resp = requests.post(config["url"], json=create_payload, headers=headers, timeout=config["timeout"])
        create_resp.raise_for_status()
        create_result = create_resp.json()
        if create_result.get("code") != "OK":
            raise ValueError(f"Failed to create the package | API response: {create_result}")
        packageID = create_result["data"]["newPackage"]["packageID"]
        logger.info(f"[{task_name}] Step 1: package created | packageID: {packageID} | cluster: {cluster_id}")
        # Step 2: notify upload completion (reusing the already-uploaded dataset)
        config = API_CONFIG["notify_upload"]
        notify_payload = {
            "userID": 5,
            "packageID": packageID,
            "uploadParams": {
                "dataType": "dataset",
                "uploadInfo": {"type": "local", "localPath": task["dataset_name"], "objectIDs": []}
            }
        }
        notify_resp = requests.post(config["url"], json=notify_payload, headers=headers, timeout=config["timeout"])
        notify_resp.raise_for_status()
        logger.info(f"[{task_name}] Step 2: upload-complete notification sent")
        # Step 3: bind the dataset to the selected cluster
        config = API_CONFIG["bind_cluster"]
        bind_payload = {
            "userID": 5,
            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": [cluster_id]}
        }
        bind_resp = requests.post(config["url"], json=bind_payload, headers=headers, timeout=config["timeout"])
        bind_resp.raise_for_status()
        logger.info(f"[{task_name}] Step 3: dataset bound to cluster {cluster_id}")
        # Step 4: submit the training task (using the dataset_target_id obtained earlier)
        config = API_CONFIG["submit_task"]
        task_res = task["resource"]
        submit_payload = {
            "userID": 5,
            "jobSetInfo": {
                "jobs": [
                    {
                        "localJobID": "1",
                        "name": task_name,
                        "description": "Automatically submitted training task",
                        "type": "AI",
                        "files": {
                            "dataset": {"type": "Binding", "bindingID": dataset_target_id},  # ID obtained during the dataset check
                            "model": {"type": "Binding", "bindingID": ""},
                            "image": {"type": "Image", "imageID": 11}
                        },
                        "jobResources": {
                            "scheduleStrategy": "dataLocality",
                            "clusters": [
                                {
                                    "clusterID": cluster_id,
                                    "runtime": {"envs": {}, "params": {}},
                                    "code": {"type": "Binding", "bindingID": son_code_id},
                                    "resources": [
                                        {"type": "CPU", "name": "ARM", "number": task_res["CPU"]},
                                        {"type": "MEMORY", "name": "RAM", "number": task_res["MEMORY"]},
                                        {"type": "MEMORY", "name": "VRAM", "number": 32},
                                        {"type": "STORAGE", "name": "DISK", "number": 886},
                                        {"type": "NPU", "name": "ASCEND910", "number": task_res.get("NPU", 0)}
                                    ]
                                }
                            ]
                        }
                    }
                ]
            }
        }
        response = requests.post(
            config["url"],
            json=submit_payload,
            headers=headers,
            timeout=config["timeout"]
        )
        response.raise_for_status()
        submit_resp = response.json()
        if submit_resp.get("code") != "OK":
            raise ValueError(f"Task submission failed | API response: {submit_resp}")
        third_party_task_id = submit_resp.get('data', {}).get('jobSetID')
        logger.info(f"[{task_name}] Step 4: task submitted to cluster {cluster_id} | JointCloud task ID: {third_party_task_id}")
        # Mark the task as succeeded
        with task_map_lock:
            task["status"] = TASK_STATUS["SUCCEED"]
            task["third_party_task_id"] = third_party_task_id
        return True
    except Exception as e:
        error_msg = f"Submission failed: {str(e)}"
        # Release the reserved cluster resources on failure
        with cluster_lock:
            if cluster_id in cluster_resources:
                for res_type, required in task["resource"].items():
                    if res_type in cluster_resources[cluster_id]["available"]:
                        cluster_resources[cluster_id]["available"][res_type] += required
                logger.info(f"Task failed; released cluster {cluster_id} resources: {task['resource']}")
        with task_map_lock:
            task["fail_count"] += 1
            if task["fail_count"] >= task["max_fail_threshold"]:
                task["status"] = TASK_STATUS["RETRY_EXHAUSTED"]
            else:
                task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = error_msg
        logger.error(f"[{task_name}] {error_msg}", exc_info=True)
        return False
def query_third_party_task_status(third_party_task_id: str) -> Optional[Dict]:
    """
    Query a task's status on the JointCloud platform (adapted to the new list API).
    Args:
        third_party_task_id: JointCloud task ID
    Returns:
        A dict with status and end time, e.g.
        {"status": "Succeeded", "end_time": "2025-07-17T20:57:30+08:00"};
        None if the task was not found or the query failed.
    """
    if not third_party_task_id:
        logger.warning("Empty JointCloud task ID; cannot query status")
        return None
    try:
        # New API endpoint and parameters
        url = "http://119.45.255.234:30180/pcm/v1/core/task/list"
        params = {
            "pageNum": 1,
            "pageSize": 10,
            "type": 1
        }
        # Obtain an auth token
        token = get_token()
        if not token:
            logger.error("Failed to obtain a token; cannot query task status")
            return None
        headers = {"Authorization": f"Bearer {token}"}
        # Send the GET request
        response = requests.get(
            url,
            params=params,
            headers=headers,
            timeout=15  # request timeout
        )
        response.raise_for_status()  # check the HTTP status code
        result = response.json()
        # Validate the API response
        if result.get("code") != 200 or "data" not in result or "list" not in result["data"]:
            logger.error(f"Task status API returned an unexpected response: {result}")
            return None
        # Scan the task list for the target task
        task_list = result["data"]["list"]
        target_task = None
        for task in task_list:
            if task.get("id") == third_party_task_id:
                target_task = task
                break
        if not target_task:
            logger.warning(f"Task with ID {third_party_task_id} not found in the task list")
            return None
        # Extract status and time information
        task_status = target_task.get("status")
        end_time = target_task.get("endTime")  # success/failure time
        logger.info(
            f"Task {third_party_task_id} queried | status: {task_status} | "
            f"end time: {end_time or 'not finished'}"
        )
        return {
            "status": task_status,
            "end_time": end_time
        }
    except requests.exceptions.RequestException as e:
        logger.error(
            f"Task status request failed | task ID: {third_party_task_id} | error: {str(e)}",
            exc_info=True
        )
        return None
    except Exception as e:
        logger.error(
            f"Failed to parse task status | task ID: {third_party_task_id} | error: {str(e)}",
            exc_info=True
        )
        return None
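# NOTE: the list endpoint is paginated and only the first page (pageSize=10) is
# scanned, so a task that has fallen off the first page may not be found.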
# -------------------------- Thread 1: task monitor --------------------------
class TaskMonitorThread(threading.Thread):
    """Monitor thread: tracks task status only"""
    def __init__(self, check_interval: int = 10):
        super().__init__(name="TaskMonitorThread")
        self.check_interval = check_interval
        self._stop_event = threading.Event()

    def run(self) -> None:
        logger.info(f"Monitor thread started | poll interval: {self.check_interval}s")
        while not self._stop_event.is_set():
            with task_map_lock:
                tasks = list(task_map.values())
            for task in tasks:
                with task_map_lock:
                    current_status = task["status"]
                if current_status == TASK_STATUS["SUBMITTED"]:
                    continue
                elif current_status == TASK_STATUS["SUBMITTING"]:
                    if not task["third_party_task_id"]:
                        logger.warning(f"Task {task['task_name']} has no JointCloud ID; skipping status query")
                        continue
                    third_party_info = query_third_party_task_status(task["third_party_task_id"])
                    if third_party_info:
                        with task_map_lock:
                            if third_party_info["status"] == "Succeeded":
                                task["status"] = TASK_STATUS["SUCCEED"]
                                task["success_time"] = third_party_info["end_time"]  # record the success time
                                logger.info(
                                    f"Task status updated | task_name: {task['task_name']} | succeeded | "
                                    f"success time: {task['success_time']}"
                                )
                            elif third_party_info["status"] in ("Failed", "Saved"):
                                task["status"] = TASK_STATUS["FAILED"]
                                task["fail_count"] += 1  # increment the failure count
                                task["error_msg"] = f"JointCloud task execution failed (attempt {task['fail_count']})"
                                logger.warning(
                                    f"Task status updated | task_name: {task['task_name']} | failed | "
                                    f"failure count: {task['fail_count']}/{task['max_fail_threshold']}"
                                )
            if self._check_all_tasks_completed():
                logger.info("All tasks completed")
                self.stop()
            self._stop_event.wait(self.check_interval)
        logger.info("Monitor thread stopped")

    def _check_all_tasks_completed(self) -> bool:
        """Check whether all tasks have reached a terminal state"""
        with task_map_lock:
            for task in task_map.values():
                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
                    return False
                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                    return False
        return True

    def stop(self) -> None:
        self._stop_event.set()
# -------------------------- Thread 2: task submitter --------------------------
class TaskSubmitThread(threading.Thread):
    """Submit thread: decides from task status whether to submit"""
    def __init__(self, max_workers: int = 3):
        super().__init__(name="TaskSubmitThread")
        self.max_workers = max_workers
        self._stop_event = threading.Event()

    def run(self) -> None:
        logger.info(f"Submit thread started | concurrency: {self.max_workers}")
        while not self._stop_event.is_set():
            # Collect tasks awaiting submission
            with task_map_lock:
                pending_tasks = []
                for task in task_map.values():
                    status = task["status"]
                    if status == TASK_STATUS["SUBMITTED"]:
                        pending_tasks.append(task)
                    elif status == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                        pending_tasks.append(task)
                    elif status == TASK_STATUS["FAILED"]:
                        logger.info(f"Task {task['task_name']} exceeded the failure threshold; no more submissions")
            # Submit the pending tasks concurrently
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {executor.submit(self.commit_task, task): task for task in pending_tasks}
                for future in concurrent.futures.as_completed(futures):
                    task = futures[future]
                    try:
                        future.result()
                    except Exception as e:
                        with task_map_lock:
                            task["status"] = TASK_STATUS["FAILED"]
                            task["fail_count"] += 1
                            task["error_msg"] = f"Exception during submission: {str(e)}"
                        logger.error(f"Task submission exception | task_name: {task['task_name']} | error: {str(e)}")
            if self._check_all_tasks_completed():
                logger.info("All tasks completed; submit thread exiting")
                self.stop()
                break
            if not pending_tasks:
                logger.info("No tasks awaiting submission; waiting for the next check")
            self._stop_event.wait(5)
        logger.info("Submit thread stopped")

    def commit_task(self, task: Dict) -> None:
        """Task submission entry point: select a cluster first, then submit"""
        cluster_id = select_cluster(task["resource"])
        if not cluster_id:
            with task_map_lock:
                task["status"] = TASK_STATUS["FAILED"]
                task["fail_count"] += 1
                task["error_msg"] = "No available cluster"
            logger.error(f"[{task['task_name']}] Submission failed: no available cluster")
            return
        # Mark as submitting
        with task_map_lock:
            task["status"] = TASK_STATUS["SUBMITTING"]
            task["cluster_id"] = cluster_id
        logger.info(f"[{task['task_name']}] Submitting to cluster {cluster_id}")
        # Perform the submission
        submit_success = submit_single_task(task)
        if not submit_success:
            logger.warning(f"[{task['task_name']}] Submission failed; will retry (current failure count: {task['fail_count']})")

    def _check_all_tasks_completed(self) -> bool:
        """Check whether all tasks have reached a terminal state"""
        with task_map_lock:
            for task in task_map.values():
                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
                    return False
                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                    return False
        return True

    def stop(self) -> None:
        self._stop_event.set()
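# NOTE: both thread classes carry an identical _check_all_tasks_completed(); a
# task counts as finished once it is SUCCEED, RETRY_EXHAUSTED, or FAILED with no
# retry budget left, and both threads shut themselves down at that point.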
# -------------------------- Main program --------------------------
if __name__ == "__main__":
    task_templates = generate_task_templates()
    load_tasks_to_queue(task_templates)
    monitor_thread = TaskMonitorThread(check_interval=10)
    monitor_thread.start()
    submit_thread = TaskSubmitThread(max_workers=1)  # max_workers=1 submits tasks serially
    submit_thread.start()
    monitor_thread.join()
    submit_thread.join()
    logger.info("All tasks processed; exiting")