# SchedulingSimulator/commit_task_jointcloud_isda...
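"""Task-submission simulator for the JointCloud scheduling platform.

Loads task templates and a sub-algorithm -> cluster/dataset mapping from YAML,
uploads datasets on demand, submits training jobs through the platform's HTTP
APIs, and runs two threads: one that submits pending tasks and one that
monitors their status until every task succeeds or exhausts its retries.
"""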

import concurrent.futures
import time
import logging
import threading
from uuid import uuid4
from typing import Dict, List, Optional
import requests
import os
import json
import yaml  # used to parse the YAML config file
# -------------------------- Global configuration and constants --------------------------
# Logging configuration
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Task status definitions
TASK_STATUS = {
    "SUBMITTED": "pending submission",      # initial state
    "SUBMITTING": "submitting",             # submission in progress
    "SUCCEED": "submitted successfully",    # confirmed by the JointCloud platform
    "FAILED": "submission failed",          # rejected by the JointCloud platform
    "RETRY_EXHAUSTED": "retries exhausted"  # exceeded the max failure threshold
}
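# Status lifecycle (driven by TaskSubmitThread and TaskMonitorThread):
#   SUBMITTED -> SUBMITTING -> SUCCEED
#                           -> FAILED -> re-queued until fail_count reaches
#                                        max_fail_threshold -> RETRY_EXHAUSTED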
# Global task map (key=target_id, value=task details)
task_map: Dict[str, Dict] = {}
task_map_lock = threading.Lock()  # lock guarding task_map
# Global dataset map (key=file_location, value=DatasetInfo instance)
dataset_map: Dict[str, Dict] = {}
dataset_lock = threading.Lock()  # lock guarding dataset_map
# Sub-algorithm -> cluster/dataset mapping, loaded from YAML
ALGORITHM_MAPPING: Dict[int, Dict] = {}  # shape: {son_code_id: {"clusters": [], "file_location": ""}}
# Task templates, loaded from YAML
TASK_TEMPLATES: List[Dict] = []
# API endpoint configuration
API_CONFIG = {
"login": {
"url": "http://119.45.255.234:30180/jcc-admin/admin/login",
"timeout": 10
},
"create_package": {
"url": "http://119.45.255.234:30180/jsm/jobSet/createPackage",
"timeout": 15
},
"upload_file": {
"url": "http://121.36.5.116:32010/object/upload",
"timeout": 3000
},
"notify_upload": {
"url": "http://119.45.255.234:30180/jsm/jobSet/notifyUploaded",
"timeout": 15
},
"bind_cluster": {
"url": "http://119.45.255.234:30180/jsm/jobSet/binding",
"timeout": 15
},
"query_binding": {
"url": "http://119.45.255.234:30180/jsm/jobSet/queryBinding",
"timeout": 15
},
"submit_task": {
"url": "http://119.45.255.234:30180/jsm/jobSet/submit",
"timeout": 100
},
"task_details": {
"url": "http://119.45.255.234:30180/pcm/v1/core/task/details",
"timeout": 15
}
}
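# All endpoints go through the 119.45.255.234:30180 gateway except "upload_file",
# which talks directly to the object-storage service at 121.36.5.116:32010
# (presumably why its timeout is far larger than the others).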
# Cluster resource configuration (key=cluster ID, value=total/available resources)
cluster_resources: Dict[str, Dict] = {
"1790300942428540928": { # modelarts集群
"total": {"CPU": 1024, "MEMORY": 2048, "NPU": 56},
"available": {"CPU": 1024, "MEMORY": 2048, "NPU": 56}
},
"1790300942428540929": { # 新增集群示例
"total": {"CPU": 512, "MEMORY": 1024, "NPU": 32},
"available": {"CPU": 512, "MEMORY": 1024, "NPU": 32}
}
}
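# Accounting invariant: available[r] <= total[r] for every resource r.
# select_cluster() subtracts a task's demand from "available" when it reserves a
# cluster, and the failure paths in submit_single_task() add the same amounts back.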
cluster_lock = threading.Lock()  # lock guarding cluster_resources
# -------------------------- Data structures --------------------------
class DatasetInfo(dict):
    """Dataset metadata record."""
    def __init__(self, file_location: str, name: str, size: float, **kwargs):
        super().__init__()
        self["file_location"] = file_location  # local path (primary key)
        self["id"] = kwargs.get("id", str(uuid4()))  # unique dataset ID
        self["name"] = name  # dataset name
        self["size"] = size  # size in bytes
        self["is_uploaded"] = kwargs.get("is_uploaded", False)  # whether already uploaded
        self["upload_cluster"] = kwargs.get("upload_cluster", [])  # clusters it was uploaded to
        self["upload_time"] = kwargs.get("upload_time")  # upload timestamp
        self["description"] = kwargs.get("description")  # description
        self["dataset_target_id"] = kwargs.get("dataset_target_id")  # binding ID on the platform
class TaskInfo(dict):
    """Task metadata record."""
    def __init__(self, task_name: str, dataset_name: str, son_code_id: int, resource: Dict, **kwargs):
        super().__init__()
        self["target_id"] = kwargs.get("target_id", str(uuid4()))  # unique task ID
        self["task_name"] = task_name  # task name
        self["package_name"] = kwargs.get("package_name", f"{task_name.lower()}-pkg")  # package (folder) name
        self["dataset_name"] = dataset_name  # associated dataset name
        self["son_code_id"] = son_code_id  # sub-algorithm ID
        self["resource"] = resource  # resource demand (CPU/MEMORY/NPU, ...)
        self["status"] = TASK_STATUS["SUBMITTED"]  # initial state: pending submission
        self["submit_time"] = kwargs.get("submit_time", time.strftime("%Y-%m-%d %H:%M:%S"))  # submission time
        self["success_time"] = None  # completion time (set on success)
        self["third_party_task_id"] = ""  # JointCloud task ID (set after submission)
        self["file_location"] = kwargs.get("file_location", "")  # local file path (from the mapping)
        self["error_msg"] = ""  # last error message
        self["fail_count"] = 0  # failure count
        self["max_fail_threshold"] = kwargs.get("max_fail_threshold", 3)  # max failures before giving up
        self["cluster_id"] = ""  # target cluster ID (from the mapping)
# -------------------------- YAML config handling --------------------------
def load_yaml_config(yaml_path: str = "sonCode_cluster__mapping.yaml") -> None:
    """Load the YAML config file (algorithm mapping and task templates)."""
    global ALGORITHM_MAPPING, TASK_TEMPLATES  # rebind the module-level globals
    try:
        with open(yaml_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        ALGORITHM_MAPPING = data.get("algorithm_mapping", {})
        TASK_TEMPLATES = data.get("task_templates", [])
        logger.info(
            f"Config file loaded | algorithm mappings: {len(ALGORITHM_MAPPING)} | task templates: {len(TASK_TEMPLATES)}")
    except Exception as e:
        logger.error(f"Failed to load config file: {str(e)}", exc_info=True)
# -------------------------- Utility functions --------------------------
def generate_task_templates() -> List[Dict]:
    """Return the task templates loaded from YAML."""
    return TASK_TEMPLATES
def load_tasks_to_queue(templates: List[Dict]) -> None:
    """Load static task definitions into the task queue (file_location comes from the mapping)."""
    global task_map
    with task_map_lock:
        task_map.clear()
        for idx, template in enumerate(templates):
            try:
                # Validate required fields
                required_fields = ["task_name_template", "prefix", "dataset_name", "son_code_id", "resource"]
                missing_fields = [f for f in required_fields if f not in template]
                if missing_fields:
                    logger.warning(f"Skipping invalid task template (index {idx}): missing fields {missing_fields}")
                    continue
                # Look up file_location in the mapping
                son_code_id = template["son_code_id"]
                mapping = ALGORITHM_MAPPING.get(son_code_id)
                if not mapping:
                    logger.warning(f"No mapping for sub-algorithm ID {son_code_id}, skipping task template (index {idx})")
                    continue
                file_location = mapping["file_location"]
                # Build the task name
                task_name = template["task_name_template"].format(prefix=template["prefix"])
                # Create the task record
                task = TaskInfo(
                    task_name=task_name,
                    dataset_name=template["dataset_name"],
                    son_code_id=son_code_id,
                    resource=template["resource"],
                    file_location=file_location  # filled from the mapping
                )
                task_map[task["target_id"]] = task
                logger.info(
                    f"Task queued | task_name: {task_name} | sub-algorithm ID: {son_code_id} | dataset path: {file_location}")
            except Exception as e:
                logger.error(f"Failed to load task template (index {idx}): {str(e)}")
    logger.info(f"Task queue loaded | {len(task_map)} valid tasks")
def select_cluster(task_resource: Dict, son_code_id: int) -> Optional[str]:
    """Pick a cluster that satisfies the task's resource demand, preferring clusters from the mapping."""
    # 1. Get the clusters this sub-algorithm is mapped to
    mapping = ALGORITHM_MAPPING.get(son_code_id)
    if not mapping:
        logger.warning(f"No cluster mapping for sub-algorithm ID {son_code_id}, trying all clusters")
        candidate_clusters = list(cluster_resources.keys())
    else:
        candidate_clusters = mapping["clusters"]
    with cluster_lock:
        # 2. Check whether a candidate cluster satisfies the resource demand
        for cluster_id in candidate_clusters:
            if cluster_id not in cluster_resources:
                logger.warning(f"Mapped cluster {cluster_id} is absent from the resource config, skipping")
                continue
            cluster = cluster_resources[cluster_id]
            # Check resource availability
            resource_match = True
            for res_type, required in task_resource.items():
                available = cluster["available"].get(res_type, 0)
                if available < required:
                    resource_match = False
                    break
            if resource_match:
                # Reserve the resources
                for res_type, required in task_resource.items():
                    if res_type in cluster["available"]:
                        cluster["available"][res_type] -= required
                logger.info(f"Selected cluster {cluster_id} (mapped for sub-algorithm {son_code_id}) | available after reservation: {cluster['available']}")
                return cluster_id
        # 3. If no mapped cluster fits, try the remaining clusters
        all_clusters = list(cluster_resources.keys())
        for cluster_id in all_clusters:
            if cluster_id in candidate_clusters:
                continue  # already checked
            cluster = cluster_resources[cluster_id]
            resource_match = True
            for res_type, required in task_resource.items():
                available = cluster["available"].get(res_type, 0)
                if available < required:
                    resource_match = False
                    break
            if resource_match:
                for res_type, required in task_resource.items():
                    if res_type in cluster["available"]:
                        cluster["available"][res_type] -= required
                logger.info(f"Selected cluster {cluster_id} (outside the mapping) | available after reservation: {cluster['available']}")
                return cluster_id
    logger.warning(f"No cluster satisfies the resource demand | demand: {task_resource} | sub-algorithm: {son_code_id}")
    return None
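def release_cluster_resources(cluster_id: str, task_resource: Dict) -> None:
    """Return a task's reserved resources to a cluster.

    Convenience helper consolidating the release logic that submit_single_task()
    currently repeats inline on its failure paths; provided as a sketch and not
    wired into the original flow.
    """
    with cluster_lock:
        if cluster_id in cluster_resources:
            for res_type, required in task_resource.items():
                if res_type in cluster_resources[cluster_id]["available"]:
                    cluster_resources[cluster_id]["available"][res_type] += required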
# -------------------------- Dataset upload check and handling --------------------------
def check_and_handle_dataset(file_location: str, dataset_name: str, cluster_id: str) -> Optional[str]:
    """Ensure the dataset is uploaded to the given cluster, uploading it if necessary."""
    global dataset_map
    # Step 1: check whether the dataset is already uploaded
    dataset = get_dataset_status(file_location, dataset_name, cluster_id)
    # Step 2: upload it if not
    if not dataset or not dataset["is_uploaded"] or cluster_id not in dataset["upload_cluster"]:
        dataset = upload_dataset(file_location, dataset_name, cluster_id)
        if not dataset:
            return None  # upload failed
    return dataset["dataset_target_id"]
def get_dataset_status(file_location: str, dataset_name: str, cluster_id: str) -> Optional[Dict]:
    """Check whether the dataset has already been uploaded to the given cluster."""
    with dataset_lock:
        if file_location in dataset_map:
            dataset = dataset_map[file_location]
            # Verify it was uploaded to the target cluster
            if dataset["is_uploaded"] and cluster_id in dataset["upload_cluster"]:
                logger.info(
                    f"Dataset {dataset_name} already uploaded to cluster {cluster_id} | target_id: {dataset['dataset_target_id']}")
                return dataset
    return None
def upload_dataset(file_location: str, dataset_name: str, cluster_id: str) -> Optional[Dict]:
    """Run the dataset upload workflow."""
    dataset_path = os.path.join(file_location, dataset_name)
    # Check that the local file exists
    if not os.path.exists(dataset_path):
        logger.error(f"Local dataset file not found | path: {dataset_path}")
        return None
    # Determine the file size in bytes
    try:
        file_size = os.path.getsize(dataset_path)
    except OSError as e:
        logger.error(f"Failed to get file size | path: {dataset_path} | error: {str(e)}")
        return None
    logger.info(f"Uploading dataset {dataset_name} to cluster {cluster_id} | path: {dataset_path}")
    try:
        # Obtain an auth token
        token = get_token()
        if not token:
            logger.error("Failed to obtain a token; cannot upload the dataset")
            return None
        headers = {'Authorization': f'Bearer {token}'}
        package_name = f"dataset-{dataset_name.split('.')[0]}-{uuid4().hex[:6]}"  # unique folder name
        # 1. Create the dataset package (folder)
        create_payload = {
            "userID": 5,
            "name": package_name,
            "dataType": "dataset",
            "packageID": 0,
            "uploadPriority": {"type": "specify", "clusters": [cluster_id]},
            "bindingInfo": {
                "clusterIDs": [cluster_id],
                "name": package_name,
                "category": "image",
                "type": "dataset",
                "imageID": "",
                "chip": ["ASCEND"],
                "packageID": 0
            }
        }
        create_resp = requests.post(
            API_CONFIG["create_package"]["url"],
            json=create_payload,
            headers=headers,
            timeout=API_CONFIG["create_package"]["timeout"]
        )
        create_resp.raise_for_status()
        create_result = create_resp.json()
        if create_result.get("code") != "OK":
            raise ValueError(f"Failed to create package | response: {create_result}")
        packageID = create_result["data"]["newPackage"]["packageID"]
        logger.info(f"Dataset package created | packageID: {packageID}")
        # 2. Upload the file
        info_data = {
            "userID": 5,
            "packageID": packageID,
            "loadTo": [3],
            "loadToPath": [f"/dataset/5/{package_name}/"]
        }
        with open(dataset_path, 'rb') as f:
            form_data = {"info": (None, json.dumps(info_data)), "files": f}
            upload_resp = requests.post(
                API_CONFIG["upload_file"]["url"],
                files=form_data,
                headers=headers,
                timeout=API_CONFIG["upload_file"]["timeout"]
            )
        upload_resp.raise_for_status()
        upload_result = upload_resp.json()
        if upload_result.get("code") != "OK":
            raise ValueError(f"File upload failed | response: {upload_result}")
        object_id = upload_result["data"]["uploadeds"][0]["objectID"]
        logger.info(f"Dataset file uploaded | objectID: {object_id}")
        # 3. Notify that the upload is complete
        notify_payload = {
            "userID": 5,
            "packageID": packageID,
            "uploadParams": {
                "dataType": "dataset",
                "uploadInfo": {"type": "local", "localPath": dataset_name, "objectIDs": [object_id]}
            }
        }
        notify_resp = requests.post(
            API_CONFIG["notify_upload"]["url"],
            json=notify_payload,
            headers=headers,
            timeout=API_CONFIG["notify_upload"]["timeout"]
        )
        notify_resp.raise_for_status()
        # 4. Bind the dataset to the cluster
        bind_payload = {
            "userID": 5,
            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": [cluster_id]}
        }
        bind_resp = requests.post(
            API_CONFIG["bind_cluster"]["url"],
            json=bind_payload,
            headers=headers,
            timeout=API_CONFIG["bind_cluster"]["timeout"]
        )
        bind_resp.raise_for_status()
        # 5. Query the binding ID (dataset_target_id)
        query_resp = requests.post(
            API_CONFIG["query_binding"]["url"],
            json={"dataType": "dataset", "param": {"userID": 5, "bindingID": -1, "type": "private"}},
            headers=headers,
            timeout=API_CONFIG["query_binding"]["timeout"]
        ).json()
        if query_resp.get("code") != "OK":
            raise ValueError(f"Failed to query binding ID | response: {query_resp}")
        dataset_target_id = None
        for item in query_resp["data"]["datas"]:
            if item["info"]["name"] == package_name:
                dataset_target_id = item["ID"]
                break
        if not dataset_target_id:
            raise ValueError(f"No binding ID found for dataset {package_name}")
        # Upload succeeded: create and store the dataset record
        dataset = DatasetInfo(
            file_location=file_location,
            name=dataset_name,
            size=file_size,
            is_uploaded=True,
            upload_cluster=[cluster_id],
            upload_time=time.strftime("%Y-%m-%d %H:%M:%S"),
            dataset_target_id=dataset_target_id
        )
        with dataset_lock:
            dataset_map[file_location] = dataset
        logger.info(f"Dataset {dataset_name} uploaded | target_id: {dataset_target_id}")
        return dataset
    except Exception as e:
        logger.error(f"Dataset upload failed | name: {dataset_name} | error: {str(e)}", exc_info=True)
        return None
# -------------------------- API helpers --------------------------
def get_token() -> Optional[str]:
    """Obtain an auth token."""
    login_payload = {"username": "admin", "password": "Nudt@123"}
    try:
        config = API_CONFIG["login"]
        response = requests.post(config["url"], json=login_payload, timeout=config["timeout"])
        response.raise_for_status()
        result = response.json()
        if result.get("code") == 200 and "data" in result and "token" in result["data"]:
            logger.info("Token obtained")
            return result["data"]["token"]
        else:
            logger.error(f"Failed to obtain token | response: {result}")
            return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Login request failed: {str(e)}", exc_info=True)
        return None
def submit_single_task(task: Dict) -> bool:
    """Submit a single task to its cluster."""
    token = get_token()
    if not token:
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "failed to obtain token"
        return False
    # Read the selected cluster ID
    cluster_id = task.get("cluster_id")
    if not cluster_id:
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "no cluster ID assigned"
        logger.error(f"[{task['task_name']}] Submission failed: no cluster ID assigned")
        return False
    # Check (and if necessary upload) the dataset
    dataset_target_id = check_and_handle_dataset(
        file_location=task["file_location"],
        dataset_name=task["dataset_name"],
        cluster_id=cluster_id
    )
    if not dataset_target_id:
        # Dataset upload failed: release the reserved cluster resources
        with cluster_lock:
            if cluster_id in cluster_resources:
                for res_type, required in task["resource"].items():
                    if res_type in cluster_resources[cluster_id]["available"]:
                        cluster_resources[cluster_id]["available"][res_type] += required
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "dataset upload failed"
        logger.error(f"[{task['task_name']}] Submission failed: dataset handling failed")
        return False
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {token}'
    }
    task_name = task["task_name"]
    package_name = task["package_name"]
    son_code_id = task["son_code_id"]  # sub-algorithm ID
    try:
        # 1. Create the dataset package (folder)
        config = API_CONFIG["create_package"]
        create_payload = {
            "userID": 5,
            "name": package_name,
            "dataType": "dataset",
            "packageID": 0,
            "uploadPriority": {"type": "specify", "clusters": [cluster_id]},
            "bindingInfo": {
                "clusterIDs": [cluster_id],
                "name": package_name,
                "category": "image",
                "type": "dataset",
                "imageID": "",
                "bias": [],
                "region": [],
                "chip": ["ASCEND"],
                "selectedCluster": [],
                "modelType": "",
                "env": "",
                "version": "",
                "packageID": 0,
                "points": 0
            }
        }
        create_resp = requests.post(config["url"], json=create_payload, headers=headers, timeout=config["timeout"])
        create_resp.raise_for_status()
        create_result = create_resp.json()
        if create_result.get("code") != "OK":
            raise ValueError(f"Failed to create package | API response: {create_result}")
        packageID = create_result["data"]["newPackage"]["packageID"]
        logger.info(f"[{task_name}] Step 1: package created | packageID: {packageID} | cluster: {cluster_id}")
        # 2. Notify that the upload is complete
        config = API_CONFIG["notify_upload"]
        notify_payload = {
            "userID": 5,
            "packageID": packageID,
            "uploadParams": {
                "dataType": "dataset",
                "uploadInfo": {"type": "local", "localPath": task["dataset_name"], "objectIDs": []}
            }
        }
        notify_resp = requests.post(config["url"], json=notify_payload, headers=headers, timeout=config["timeout"])
        notify_resp.raise_for_status()
        logger.info(f"[{task_name}] Step 2: upload notification sent")
        # 3. Bind the dataset to the cluster
        config = API_CONFIG["bind_cluster"]
        bind_payload = {
            "userID": 5,
            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": [cluster_id]}
        }
        bind_resp = requests.post(config["url"], json=bind_payload, headers=headers, timeout=config["timeout"])
        bind_resp.raise_for_status()
        logger.info(f"[{task_name}] Step 3: dataset bound to cluster {cluster_id}")
        # 4. Submit the training job
        config = API_CONFIG["submit_task"]
        task_res = task["resource"]
        submit_payload = {
            "userID": 5,
            "jobSetInfo": {
                "jobs": [
                    {
                        "localJobID": "1",
                        "name": task_name,
                        "description": "automatically submitted training task",
                        "type": "AI",
                        "files": {
                            "dataset": {"type": "Binding", "bindingID": dataset_target_id},
                            "model": {"type": "Binding", "bindingID": ""},
                            "image": {"type": "Image", "imageID": 11}
                        },
                        "jobResources": {
                            "scheduleStrategy": "dataLocality",
                            "clusters": [
                                {
                                    "clusterID": cluster_id,
                                    "runtime": {"envs": {}, "params": {}},
                                    "code": {"type": "Binding", "bindingID": son_code_id},
                                    "resources": [
                                        {"type": "CPU", "name": "ARM", "number": task_res["CPU"]},
                                        {"type": "MEMORY", "name": "RAM", "number": task_res["MEMORY"]},
                                        {"type": "MEMORY", "name": "VRAM", "number": 32},
                                        {"type": "STORAGE", "name": "DISK", "number": 886},
                                        {"type": "NPU", "name": "ASCEND910", "number": task_res.get("NPU", 0)}
                                    ]
                                }
                            ]
                        }
                    }
                ]
            }
        }
        response = requests.post(
            config["url"],
            json=submit_payload,
            headers=headers,
            timeout=config["timeout"]
        )
        response.raise_for_status()
        submit_resp = response.json()
        if submit_resp.get("code") != "OK":
            raise ValueError(f"Task submission failed | API response: {submit_resp}")
        third_party_task_id = submit_resp.get('data', {}).get('jobSetID')
        logger.info(f"[{task_name}] Step 4: task submitted to cluster {cluster_id} | JointCloud task ID: {third_party_task_id}")
        # Mark the task as succeeded
        with task_map_lock:
            task["status"] = TASK_STATUS["SUCCEED"]
            task["third_party_task_id"] = third_party_task_id
        return True
    except Exception as e:
        error_msg = f"submission failed: {str(e)}"
        # On failure, release the reserved cluster resources
        with cluster_lock:
            if cluster_id in cluster_resources:
                for res_type, required in task["resource"].items():
                    if res_type in cluster_resources[cluster_id]["available"]:
                        cluster_resources[cluster_id]["available"][res_type] += required
                logger.info(f"Task failed; released resources on cluster {cluster_id}: {task['resource']}")
        with task_map_lock:
            task["fail_count"] += 1
            if task["fail_count"] >= task["max_fail_threshold"]:
                task["status"] = TASK_STATUS["RETRY_EXHAUSTED"]
            else:
                task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = error_msg
        logger.error(f"[{task_name}] {error_msg}", exc_info=True)
        return False
def query_third_party_task_status(third_party_task_id: str) -> Optional[Dict]:
    """Query the task's status on the JointCloud platform."""
    if not third_party_task_id:
        logger.warning("Empty JointCloud task ID; cannot query status")
        return None
    try:
        url = "http://119.45.255.234:30180/pcm/v1/core/task/list"
        params = {
            "pageNum": 1,
            "pageSize": 10,
            "type": 1
        }
        token = get_token()
        if not token:
            logger.error("Failed to obtain a token; cannot query task status")
            return None
        headers = {"Authorization": f"Bearer {token}"}
        response = requests.get(
            url,
            params=params,
            headers=headers,
            timeout=15
        )
        response.raise_for_status()
        result = response.json()
        if result.get("code") != 200 or "data" not in result or "list" not in result["data"]:
            logger.error(f"Unexpected response from the task-status API | response: {result}")
            return None
        task_list = result["data"]["list"]
        target_task = None
        for task in task_list:
            # Compare as strings in case the list API's "id" differs in type
            # from the jobSetID returned at submission time
            if str(task.get("id")) == str(third_party_task_id):
                target_task = task
                break
        if not target_task:
            logger.warning(f"Task {third_party_task_id} not found")
            return None
        task_status = target_task.get("status")
        end_time = target_task.get("endTime")
        logger.info(f"Task {third_party_task_id} status: {task_status} | end time: {end_time or 'not finished'}")
        return {
            "status": task_status,
            "end_time": end_time
        }
    except Exception as e:
        logger.error(f"Failed to query task status | task ID: {third_party_task_id} | error: {str(e)}", exc_info=True)
        return None
# -------------------------- Thread 1: task monitor --------------------------
class TaskMonitorThread(threading.Thread):
    """Monitor thread: tracks task status."""
    def __init__(self, check_interval: int = 10):
        super().__init__(name="TaskMonitorThread")
        self.check_interval = check_interval
        self._stop_event = threading.Event()
    def run(self) -> None:
        logger.info(f"Monitor thread started | poll interval: {self.check_interval}s")
        while not self._stop_event.is_set():
            with task_map_lock:
                tasks = list(task_map.values())
            for task in tasks:
                with task_map_lock:
                    current_status = task["status"]
                if current_status in [TASK_STATUS["SUBMITTED"], TASK_STATUS["RETRY_EXHAUSTED"]]:
                    continue
                if current_status == TASK_STATUS["SUBMITTING"] and task["third_party_task_id"]:
                    third_party_info = query_third_party_task_status(task["third_party_task_id"])
                    if third_party_info:
                        with task_map_lock:
                            if third_party_info["status"] == "Succeeded":
                                task["status"] = TASK_STATUS["SUCCEED"]
                                task["success_time"] = third_party_info["end_time"]
                                logger.info(f"Task {task['task_name']} succeeded | time: {task['success_time']}")
                            elif third_party_info["status"] in ["Failed", "Saved"]:
                                task["status"] = TASK_STATUS["FAILED"]
                                task["fail_count"] += 1
                                task["error_msg"] = f"JointCloud task failed (attempt {task['fail_count']})"
                                logger.warning(f"Task {task['task_name']} failed | attempts: {task['fail_count']}")
            if self._check_all_completed():
                logger.info("All tasks completed")
                self.stop()
            self._stop_event.wait(self.check_interval)
        logger.info("Monitor thread exiting")
    def _check_all_completed(self) -> bool:
        """Check whether every task has reached a terminal state."""
        with task_map_lock:
            for task in task_map.values():
                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
                    return False
                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                    return False
            return True
    def stop(self) -> None:
        self._stop_event.set()
# -------------------------- Thread 2: task submitter --------------------------
class TaskSubmitThread(threading.Thread):
    """Submitter thread: decides from task status what to submit."""
    def __init__(self, max_workers: int = 3):
        super().__init__(name="TaskSubmitThread")
        self.max_workers = max_workers
        self._stop_event = threading.Event()
    def run(self) -> None:
        logger.info(f"Submitter thread started | concurrency: {self.max_workers}")
        while not self._stop_event.is_set():
            # Collect tasks eligible for submission
            with task_map_lock:
                pending_tasks = []
                for task in task_map.values():
                    status = task["status"]
                    if status == TASK_STATUS["SUBMITTED"]:
                        pending_tasks.append(task)
                    elif status == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                        pending_tasks.append(task)
                    elif status == TASK_STATUS["FAILED"]:
                        logger.info(f"Task {task['task_name']} exceeded the failure limit; no more submissions")
            # Submit concurrently
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {executor.submit(self.commit_task, task): task for task in pending_tasks}
                for future in concurrent.futures.as_completed(futures):
                    task = futures[future]
                    try:
                        future.result()
                    except Exception as e:
                        with task_map_lock:
                            task["status"] = TASK_STATUS["FAILED"]
                            task["fail_count"] += 1
                            task["error_msg"] = f"submission error: {str(e)}"
                        logger.error(f"Task {task['task_name']} submission error | error: {str(e)}")
            if self._check_all_completed():
                logger.info("All tasks completed; submitter thread exiting")
                self.stop()
                break
            if not pending_tasks:
                logger.info("No pending tasks; waiting 5 seconds")
                self._stop_event.wait(5)
        logger.info("Submitter thread exiting")
    def commit_task(self, task: Dict) -> None:
        """Submission entry point: pick a cluster based on the mapping."""
        son_code_id = task["son_code_id"]
        cluster_id = select_cluster(task["resource"], son_code_id)
        if not cluster_id:
            with task_map_lock:
                task["status"] = TASK_STATUS["FAILED"]
                task["fail_count"] += 1
                task["error_msg"] = "no available cluster"
            logger.error(f"Task {task['task_name']} submission failed: no available cluster")
            return
        # Mark as submitting
        with task_map_lock:
            task["status"] = TASK_STATUS["SUBMITTING"]
            task["cluster_id"] = cluster_id
        logger.info(f"Task {task['task_name']} submitting to cluster {cluster_id} (sub-algorithm {son_code_id})")
        # Perform the submission
        submit_success = submit_single_task(task)
        if not submit_success:
            logger.warning(f"Task {task['task_name']} submission failed; will retry (current attempts: {task['fail_count']})")
    def _check_all_completed(self) -> bool:
        """Check whether every task has reached a terminal state."""
        with task_map_lock:
            for task in task_map.values():
                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
                    return False
                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                    return False
            return True
    def stop(self) -> None:
        self._stop_event.set()
# -------------------------- Main --------------------------
if __name__ == "__main__":
    # Load the YAML config (algorithm mapping and task templates)
    load_yaml_config()
    # Build the task templates and load them into the queue
    task_templates = generate_task_templates()
    load_tasks_to_queue(task_templates)
    # Start the threads
    monitor_thread = TaskMonitorThread(check_interval=10)
    monitor_thread.start()
    submit_thread = TaskSubmitThread(max_workers=3)
    submit_thread.start()
    # Wait for the threads to finish
    monitor_thread.join()
    submit_thread.join()
    logger.info("All tasks processed; exiting")