# SchedulingSimulator/commit_task_jointcloud_0716.py
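"""Two-thread scheduling simulator for the JointCloud platform.

A submit thread pushes AI training tasks through the third-party job-set
API (create package, upload dataset, notify, bind to cluster, submit),
while a monitor thread polls the platform until every task has either
succeeded or exhausted its retry budget.
"""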


import concurrent.futures
import time
import logging
import threading
from uuid import uuid4
from typing import Dict, List, Optional
import requests
import os
import json
# -------------------------- Global configuration and constants --------------------------
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Task status definitions
TASK_STATUS = {
    "SUBMITTED": "pending",                  # Initial state
    "SUBMITTING": "submitting",              # Submission in progress
    "SUCCEED": "succeeded",                  # Confirmed successful by the third party
    "FAILED": "failed",                      # Confirmed failed by the third party
    "RETRY_EXHAUSTED": "retries exhausted"   # fail_count reached max_fail_threshold
}
# Global task dict (key=target_id, value=task details)
task_map: Dict[str, Dict] = {}
task_map_lock = threading.Lock()  # Lock guarding task_map
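# State machine (driven by TaskSubmitThread / TaskMonitorThread below):
#   SUBMITTED -> SUBMITTING -> SUCCEED
#   SUBMITTING -> FAILED -> resubmitted while fail_count < max_fail_threshold,
#                           else RETRY_EXHAUSTED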
# API configuration (adds the task-details query endpoint)
API_CONFIG = {
    "login": {
        "url": "http://119.45.255.234:30180/jcc-admin/admin/login",
        "timeout": 10
    },
    "create_package": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/createPackage",
        "timeout": 15
    },
    "upload_file": {
        "url": "http://119.45.255.234:30180/jcs/object/upload",
        "timeout": 300
    },
    "notify_upload": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/notifyUploaded",
        "timeout": 15
    },
    "bind_cluster": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/binding",
        "timeout": 15
    },
    "query_binding": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/queryBinding",
        "timeout": 15
    },
    "submit_task": {
        "url": "http://119.45.255.234:30180/jsm/jobSet/submit",
        "timeout": 15
    },
    "task_details": {  # Task-details query endpoint (newly added)
        "url": "http://119.45.255.234:30180/pcm/v1/core/task/details",
        "timeout": 15
    }
}
# Cluster resource configuration (key=cluster ID, value=total/available resources)
cluster_resources: Dict[str, Dict] = {
    "1790300942428540928": {  # ModelArts cluster
        "total": {"CPU": 96, "MEMORY": 1024, "NPU": 2},
        "available": {"CPU": 48, "MEMORY": 512, "NPU": 1}
    },
    "1865927992266461184": {  # OpenI cluster
        "total": {"CPU": 48, "MEMORY": 512, "DCU": 1},
        "available": {"CPU": 24, "MEMORY": 256, "DCU": 1}
    },
    "1865927992266462181": {  # Octopus cluster
        "total": {"CPU": 48, "MEMORY": 512, "DCU": 1},
        "available": {"CPU": 24, "MEMORY": 256, "DCU": 1}
    },
    "1777240145309732864": {  # Sugon cluster
        "total": {"CPU": 48, "MEMORY": 512, "NPU": 1},
        "available": {"CPU": 24, "MEMORY": 256, "NPU": 1}
    },
}
cluster_lock = threading.Lock()  # Lock guarding cluster resources
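# NOTE: select_cluster() below only checks these "available" figures against a
# task's demand; nothing deducts resources on submission, so this table is
# static capacity data for the simulation.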
# -------------------------- Data structures --------------------------
class DatasetInfo(dict):
    """Dataset info structure."""
    def __init__(self, file_location: str, name: str, size: float, **kwargs):
        super().__init__()
        self["file_location"] = file_location  # Local path (primary key)
        self["id"] = kwargs.get("id", str(uuid4()))  # Unique dataset ID
        self["name"] = name  # Dataset name
        self["size"] = size  # Size in bytes
        self["is_uploaded"] = kwargs.get("is_uploaded", False)  # Uploaded yet?
        self["upload_cluster"] = kwargs.get("upload_cluster", [])  # Clusters uploaded to
        self["upload_time"] = kwargs.get("upload_time")  # Upload time
        self["description"] = kwargs.get("description")  # Description
class AlgorithmInfo(dict):
    """Algorithm info structure."""
    def __init__(self, cluster: str, id: str, name: str, **kwargs):
        super().__init__()
        self["cluster"] = cluster  # Owning cluster
        self["id"] = id  # Unique algorithm ID
        self["son_id"] = kwargs.get("son_id", "")  # Sub-algorithm ID
        self["name"] = name  # Algorithm name
class TaskInfo(dict):
    """Task info structure (adds success_time to record when the task succeeded)."""
    def __init__(self, task_name: str, dataset_name: str, code_id: str, resource: Dict, **kwargs):
        super().__init__()
        self["target_id"] = kwargs.get("target_id", str(uuid4()))  # Unique task ID
        self["task_name"] = task_name  # Task name
        self["package_name"] = kwargs.get("package_name", f"{task_name.lower()}-pkg")  # Package (folder) name
        self["dataset_name"] = dataset_name  # Associated dataset name
        self["code_id"] = code_id  # Algorithm ID
        self["son_code_id"] = ""  # Sub-algorithm ID (filled at submission)
        self["resource"] = resource  # Resource demand (CPU/MEMORY/NPU, etc.)
        self["status"] = TASK_STATUS["SUBMITTED"]  # Initial state: pending
        self["submit_time"] = kwargs.get("submit_time", time.strftime("%Y-%m-%d %H:%M:%S"))  # Submit time
        self["success_time"] = None  # Success time (filled on success)
        self["third_party_task_id"] = ""  # Third-party task ID (filled after submission)
        self["file_location"] = kwargs.get("file_location", "")  # Local file path
        self["error_msg"] = ""  # Error message
        self["fail_count"] = 0  # Failure count (renamed from retry_count to better match its meaning)
        self["max_fail_threshold"] = kwargs.get("max_fail_threshold", 3)  # Max failures allowed
        self["cluster_id"] = ""  # Target cluster ID (filled at submission)
# -------------------------- Utilities --------------------------
def generate_task_templates() -> List[Dict]:
    """Generate the static task templates."""
    return [
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AA",
            "dataset_name": "data1.zip",
            "code_id": "1164",
            "file_location": "D:/数据集/cnn数据集/data1/",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AB",
            "dataset_name": "cifar-10-python.tar.gz",
            "code_id": "1",
            "file_location": "D:/数据集/cnn数据集/data2/",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AC",
            "dataset_name": "cifar-100-python.tar.gz",
            "code_id": "1",
            "file_location": "D:/数据集/cnn数据集/data3/",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AD",
            "dataset_name": "dev.jsonl",
            "code_id": "2",
            "file_location": "D:/数据集/transfomer数据集/BoolQ/",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AE",
            "dataset_name": "dev.jsonl",
            "file_location": "D:/数据集/transfomer数据集/BoolQ/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AF",
            "dataset_name": "ceval.zip",
            "file_location": "D:/数据集/transfomer数据集/CEval/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AG",
            "dataset_name": "CMMLU.zip",
            "file_location": "D:/数据集/transfomer数据集/CMMLU/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AH",
            "dataset_name": "mental_health.csv",
            "file_location": "D:/数据集/transfomer数据集/GLUE(imdb)/imdb/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AI",
            "dataset_name": "GSM8K.jsonl",
            "file_location": "D:/数据集/transfomer数据集/GSM8K/GSM8K/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AJ",
            "dataset_name": "human-eval.jsonl",
            "file_location": "D:/数据集/transfomer数据集/HumanEval/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        },
        {
            "task_name_template": "{prefix}-jointCloudAi-trainingtask",
            "prefix": "AK",
            "dataset_name": "HumanEval_X.zip",
            "file_location": "D:/数据集/transfomer数据集/HumanEval_X/",
            "code_id": "1",
            "resource": {"CPU": 24, "MEMORY": 256, "NPU": 1}
        }
    ]
def load_tasks_to_queue(templates: List[Dict]) -> None:
    """Load the static task templates into the task queue (task_map)."""
    global task_map
    with task_map_lock:
        task_map.clear()
        for template in templates:
            task_name = template["task_name_template"].format(prefix=template["prefix"])
            task = TaskInfo(
                task_name=task_name,
                dataset_name=template["dataset_name"],
                code_id=template["code_id"],
                resource=template["resource"],
                file_location=template["file_location"]
            )
            task_map[task["target_id"]] = task
            logger.info(f"Task enqueued | task_name: {task_name} | target_id: {task['target_id']}")
    logger.info(f"Task queue loaded | {len(task_map)} tasks in total")
def select_cluster(task_resource: Dict) -> Optional[str]:
    """Pick a cluster whose available resources satisfy the task's demand."""
    with cluster_lock:
        for cluster_id, cluster in cluster_resources.items():
            # Check each requested resource type against the cluster's availability
            # (handles different accelerator types such as NPU/DCU uniformly)
            resource_match = True
            for res_type, required in task_resource.items():
                available = cluster["available"].get(res_type, 0)
                if available < required:
                    resource_match = False
                    break
            if resource_match:
                return cluster_id
    logger.warning(f"No cluster satisfies the resource demand | demand: {task_resource}")
    return None
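# Example: select_cluster({"CPU": 24, "MEMORY": 256, "NPU": 1}) returns the first
# cluster ID whose "available" entry covers every requested type, or None if no
# cluster qualifies.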
# -------------------------- API calls --------------------------
def get_son_code_id(cluster_id: str, code_id: str) -> str:
    """Look up the sub-algorithm ID for a (cluster ID, algorithm ID) pair (mocked lookup)."""
    son_code_map = {
        ("1790300942428540928", "1"): "1-1",
        ("1790300942428540928", "2"): "2-1",
        ("1777240145309732864", "1"): "1-2",
        ("1865927992266461184", "2"): "2-2"
    }
    return son_code_map.get((cluster_id, code_id), f"{code_id}-default")
def get_token() -> Optional[str]:
    """Fetch an auth token."""
    login_payload = {"username": "admin", "password": "Nudt@123"}
    try:
        config = API_CONFIG["login"]
        response = requests.post(config["url"], json=login_payload, timeout=config["timeout"])
        response.raise_for_status()
        result = response.json()
        if result.get("code") == 200 and "data" in result and "token" in result["data"]:
            logger.info("Token acquired")
            return result["data"]["token"]
        else:
            logger.error(f"Failed to acquire token | response: {result}")
            return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Login request failed: {str(e)}", exc_info=True)
        return None
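# NOTE: credentials are hardcoded for the simulation, and a fresh token is
# fetched for every submission and status query rather than cached.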
def submit_single_task(task: Dict) -> bool:
    """Submit a single task to the cluster (on failure, update status and record the error)."""
    token = get_token()
    if not token:
        with task_map_lock:
            task["status"] = TASK_STATUS["FAILED"]
            task["error_msg"] = "Failed to acquire token"
        return False
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {token}'
    }
    task_name = task["task_name"]
    package_name = task["package_name"]
    file_name = task["dataset_name"]
    file_location = task["file_location"]
    code_id = task["code_id"]  # Corrected field name
    son_code_id = get_son_code_id(task["cluster_id"], code_id)  # Use the actual cluster ID
    file_path = os.path.join(file_location, file_name)
    try:
        # Step 1: create the dataset package (folder)
        config = API_CONFIG["create_package"]
        create_payload = {
            "userID": 5,
            "name": package_name,
            "dataType": "dataset",
            "packageID": 0,
            "uploadPriority": {"type": "specify", "clusters": ["1790300942428540928"]},
            "bindingInfo": {
                "clusterIDs": ["1790300942428540928"],
                "name": package_name,
                "category": "image",
                "type": "dataset",
                "imageID": "",
                "bias": [],
                "region": [],
                "chip": ["ASCEND"],
                "selectedCluster": [],
                "modelType": "",
                "env": "",
                "version": "",
                "packageID": 0,
                "points": 0
            }
        }
        create_resp = requests.post(config["url"], json=create_payload, headers=headers, timeout=config["timeout"])
        create_resp.raise_for_status()
        create_result = create_resp.json()
        if create_result.get("code") != 200:
            raise ValueError(f"Package creation failed | API response: {create_result}")
        packageID = create_result["data"]["newPackage"]["packageID"]
        logger.info(f"[{task_name}] Step 1: package created | packageID: {packageID}")
        # Step 3: upload the dataset file
        config = API_CONFIG["upload_file"]
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found | path: {file_path}")
        info_data = {
            "userID": 5,
            "packageID": packageID,
            "loadTo": [3],
            "loadToPath": [f"/dataset/5/{package_name}/"]
        }
        file_headers = {'Authorization': f'Bearer {token}'}
        with open(file_path, 'rb') as f:
            form_data = {"info": (None, json.dumps(info_data)), "files": f}
            upload_resp = requests.post(config["url"], files=form_data, headers=file_headers, timeout=config["timeout"])
        upload_resp.raise_for_status()
        upload_result = upload_resp.json()
        if upload_result.get("code") != 200:
            raise ValueError(f"File upload failed | API response: {upload_result}")
        object_id = upload_result["data"]["uploadeds"][0]["objectID"]
        logger.info(f"[{task_name}] Step 3: file uploaded | objectID: {object_id}")
        # Step 4: notify that the upload is complete
        config = API_CONFIG["notify_upload"]
        notify_payload = {
            "userID": 5,
            "packageID": packageID,
            "uploadParams": {
                "dataType": "dataset",
                "uploadInfo": {"type": "local", "localPath": file_name, "objectIDs": [object_id]}
            }
        }
        notify_resp = requests.post(config["url"], json=notify_payload, headers=headers, timeout=config["timeout"])
        notify_resp.raise_for_status()
        notify_result = notify_resp.json()
        if notify_result.get("code") != 200:
            raise ValueError(f"Upload notification failed | API response: {notify_result}")
        logger.info(f"[{task_name}] Step 4: upload notification succeeded")
        # Step 7: bind the dataset to the cluster
        config = API_CONFIG["bind_cluster"]
        bind_payload = {
            "userID": 5,
            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": ["1790300942428540928"]}
        }
        bind_resp = requests.post(config["url"], json=bind_payload, headers=headers, timeout=config["timeout"])
        bind_resp.raise_for_status()
        bind_result = bind_resp.json()
        if bind_result.get("code") != 200:
            raise ValueError(f"Cluster binding failed | API response: {bind_result}")
        logger.info(f"[{task_name}] Step 7: dataset bound to cluster")
        # Step 8: query the binding ID
        config = API_CONFIG["query_binding"]
        query_bind_payload = {
            "dataType": "dataset",
            "param": {"userID": 5, "bindingID": -1, "type": "private"}
        }
        query_bind_resp = requests.post(config["url"], json=query_bind_payload, headers=headers, timeout=config["timeout"]).json()
        if query_bind_resp.get("code") != 200:
            raise ValueError(f"Binding query failed | API response: {query_bind_resp}")
        # Extract the target binding ID
        target_id = None
        for data in query_bind_resp["data"]["datas"]:
            if data["info"]["name"] == package_name:
                target_id = data["ID"]
                break
        if not target_id:
            raise ValueError(f"No binding ID found for package_name={package_name}")
        logger.info(f"[{task_name}] Step 8: binding ID acquired | target_id: {target_id}")
        # Step 9: submit the training task
        config = API_CONFIG["submit_task"]
        task_res = task["resource"]
        submit_payload = {
            "userID": 5,
            "jobSetInfo": {
                "jobs": [
                    {
                        "localJobID": "1",
                        "name": task_name,
                        "description": "Auto-submitted CNN training task",
                        "type": "AI",
                        "files": {
                            "dataset": {"type": "Binding", "bindingID": target_id},
                            "model": {"type": "Binding", "bindingID": 421},
                            "image": {"type": "Image", "imageID": 11}
                        },
                        "jobResources": {
                            "scheduleStrategy": "dataLocality",
                            "clusters": [
                                {
                                    "clusterID": "1790300942428540928",
                                    "runtime": {"envs": {}, "params": {}},
                                    "code": {"type": "Binding", "bindingID": son_code_id},
                                    "resources": [
                                        {"type": "CPU", "name": "ARM", "number": task_res["CPU"]},
                                        {"type": "MEMORY", "name": "RAM", "number": task_res["MEMORY"]},
                                        {"type": "MEMORY", "name": "VRAM", "number": 32},
                                        {"type": "STORAGE", "name": "DISK", "number": 32},
                                        {"type": "NPU", "name": "ASCEND910", "number": task_res.get("NPU", 0)}
                                    ]
                                }
                            ]
                        }
                    },
                    {"localJobID": "4", "type": "DataReturn", "targetLocalJobID": "1"}
                ]
            }
        }
        submit_resp = requests.post(config["url"], json=submit_payload, headers=headers, timeout=config["timeout"]).json()
        if submit_resp.get("code") != 200:
            raise ValueError(f"Task submission failed | API response: {submit_resp}")
        third_party_task_id = submit_resp.get('data', {}).get('jobSetID')
        logger.info(f"[{task_name}] Step 9: task submitted | third-party task ID: {third_party_task_id}")
        # Mark the task as succeeded (thread-safe)
        with task_map_lock:
            task["status"] = TASK_STATUS["SUCCEED"]
            task["third_party_task_id"] = third_party_task_id  # Save the third-party task ID
        return True
    except Exception as e:
        error_msg = f"Submission failed: {str(e)}"
        with task_map_lock:
            # Check whether the retry budget is exhausted
            task["fail_count"] += 1
            if task["fail_count"] >= task["max_fail_threshold"]:
                task["status"] = TASK_STATUS["RETRY_EXHAUSTED"]
            else:
                task["status"] = TASK_STATUS["FAILED"]  # Below the limit: mark failed and await retry
            task["error_msg"] = error_msg
        logger.error(f"[{task_name}] {error_msg}", exc_info=True)
        return False
def query_third_party_task_status(third_party_task_id: str) -> Optional[str]:
    """
    Query the JointCloud platform for a task's status (real API call).
    Returns the status of the first element in subTaskInfos[].
    """
    if not third_party_task_id:
        logger.warning("Third-party task ID is empty; cannot query status")
        return None
    try:
        # Build the request (the ID goes in the query string)
        config = API_CONFIG["task_details"]
        params = {"id": third_party_task_id}
        # The task-details endpoint may require token auth, so fetch one here
        token = get_token()
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        response = requests.get(
            config["url"],
            params=params,
            headers=headers,
            timeout=config["timeout"]
        )
        response.raise_for_status()  # Raise on HTTP error status codes
        result = response.json()
        # Parse the response
        if result.get("code") != 200:
            logger.error(f"Status query failed | task ID: {third_party_task_id} | response: {result}")
            return None
        # Extract the status from subTaskInfos
        sub_task_infos = result.get("data", {}).get("subTaskInfos", [])
        if not sub_task_infos:
            logger.warning(f"Task {third_party_task_id} has no subTaskInfos data")
            return None
        # Return the first sub-task's status
        return sub_task_infos[0].get("status")
    except requests.exceptions.RequestException as e:
        logger.error(f"Status query request failed | task ID: {third_party_task_id} | error: {str(e)}", exc_info=True)
        return None
    except (KeyError, IndexError) as e:
        logger.error(f"Failed to parse status response | task ID: {third_party_task_id} | error: {str(e)}", exc_info=True)
        return None
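# The monitor thread below maps the returned "SUCCEEDED"/"FAILED" strings onto
# the local state machine; any other value leaves the task in the submitting
# state until the next polling round.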
# -------------------------- Thread 1: task monitor --------------------------
class TaskMonitorThread(threading.Thread):
    """Monitor thread: watches task status, updating only tasks in the submitting state."""
    def __init__(self, check_interval: int = 10):
        super().__init__(name="TaskMonitorThread")
        self.check_interval = check_interval  # Polling interval in seconds
        self._stop_event = threading.Event()
    def run(self) -> None:
        logger.info(f"Monitor thread started | check interval: {self.check_interval}s")
        while not self._stop_event.is_set():
            with task_map_lock:
                tasks = list(task_map.values())  # Copy the task list to stay thread-safe
            for task in tasks:
                # Read the status under the lock, then release it before any
                # network call (the lock is not reentrant)
                with task_map_lock:
                    current_status = task["status"]
                # 1. Pending: nothing to do
                if current_status == TASK_STATUS["SUBMITTED"]:
                    continue
                # 2. Submitting: poll the third-party status and update
                elif current_status == TASK_STATUS["SUBMITTING"]:
                    if not task["third_party_task_id"]:
                        logger.warning(f"Task {task['task_name']} has no third-party ID; skipping status query")
                        continue
                    # Query the third-party status by the returned task ID
                    third_status = query_third_party_task_status(task["third_party_task_id"])
                    with task_map_lock:
                        # 2.1 Third party reports success: mark succeeded, record the time
                        if third_status == "SUCCEEDED":
                            task["status"] = TASK_STATUS["SUCCEED"]
                            task["success_time"] = time.strftime("%Y-%m-%d %H:%M:%S")
                            logger.info(
                                f"Task status updated | task_name: {task['task_name']} | succeeded | success_time: {task['success_time']}")
                        # 2.2 Third party reports failure: mark failed, bump fail_count
                        elif third_status == "FAILED":
                            task["status"] = TASK_STATUS["FAILED"]
                            task["fail_count"] += 1
                            task["error_msg"] = f"Third-party task execution failed (failure #{task['fail_count']})"
                            logger.warning(
                                f"Task status updated | task_name: {task['task_name']} | failed | failures: {task['fail_count']}/{task['max_fail_threshold']}")
                        # 2.3 Third party still submitting: leave the status unchanged
                # 3. Succeeded: nothing to do
                elif current_status == TASK_STATUS["SUCCEED"]:
                    continue
                # 4. Failed: nothing to do (the submit thread decides whether to retry)
                elif current_status == TASK_STATUS["FAILED"]:
                    continue
            # Check whether every task is done (succeeded or out of retries)
            all_completed = self._check_all_tasks_completed()
            if all_completed:
                logger.info("All tasks completed (succeeded or failure count over threshold)")
                self.stop()
            # Wait for the next polling round
            self._stop_event.wait(self.check_interval)
        logger.info("Monitor thread finished")
    def _check_all_tasks_completed(self) -> bool:
        """Check whether every task is done (succeeded or out of retries)."""
        with task_map_lock:
            for task in task_map.values():
                # Pending or submitting: not done
                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
                    return False
                # Failed but under the threshold: not done (the submit thread may retry it)
                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
                    return False
        return True
    def stop(self) -> None:
        self._stop_event.set()
# -------------------------- Thread 2: task submitter --------------------------
class TaskSubmitThread(threading.Thread):
    """Submit thread: picks tasks by status, handling pending tasks and failed tasks still under the retry threshold."""
    def __init__(self, max_workers: int = 3):
        super().__init__(name="TaskSubmitThread")
        self.max_workers = max_workers  # Concurrent submissions
        self._stop_event = threading.Event()
    def run(self) -> None:
        logger.info(f"Submit thread started | workers: {self.max_workers}")
        while not self._stop_event.is_set():
            # 1. Select eligible tasks: pending, or failed but under the retry threshold
            with task_map_lock:
                pending_tasks = []
                for task in task_map.values():
                    status = task["status"]
                    # 1.1 Pending: submit directly
                    if status == TASK_STATUS["SUBMITTED"]:
                        pending_tasks.append(task)
                    # 1.2 Failed: submit again if the failure count is under the threshold
                    elif status == TASK_STATUS["FAILED"]:
                        if task["fail_count"] < task["max_fail_threshold"]:
                            pending_tasks.append(task)
                        else:
                            logger.info(
                                f"Task {task['task_name']} exceeded the failure threshold ({task['max_fail_threshold']}); no more submissions")
            if not pending_tasks:
                logger.info("No tasks to submit; waiting for the next check")
                self._stop_event.wait(5)
                continue
            # 2. Submit tasks concurrently
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {executor.submit(self.commit_task, task): task for task in pending_tasks}
                for future in concurrent.futures.as_completed(futures):
                    task = futures[future]
                    try:
                        future.result()
                    except Exception as e:
                        with task_map_lock:
                            task["status"] = TASK_STATUS["FAILED"]
                            task["fail_count"] += 1
                            task["error_msg"] = f"Submission error: {str(e)}"
                        logger.error(f"Task submission error | task_name: {task['task_name']} | error: {str(e)}")
        logger.info("Submit thread finished")
    def commit_task(self, task: Dict) -> None:
        """Entry point for submitting a task: pick a cluster, then call submit_single_task."""
        # 1. Pick a cluster and update the task
        cluster_id = select_cluster(task["resource"])
        if not cluster_id:
            with task_map_lock:
                task["status"] = TASK_STATUS["FAILED"]
                task["fail_count"] += 1
                task["error_msg"] = "No available cluster"
            logger.error(f"[{task['task_name']}] Submission failed: no available cluster")
            return
        # 2. Mark the task as submitting
        with task_map_lock:
            task["status"] = TASK_STATUS["SUBMITTING"]
            task["cluster_id"] = cluster_id  # Record the cluster ID
        logger.info(f"[{task['task_name']}] Submitting to cluster {cluster_id}")
        # 3. Call the core submission routine
        submit_success = submit_single_task(task)
        if not submit_success:
            logger.warning(f"[{task['task_name']}] Submission failed; awaiting retry (failure count: {task['fail_count']})")
    def stop(self) -> None:
        self._stop_event.set()
# -------------------------- Main program --------------------------
if __name__ == "__main__":
    # 1. Generate the static task data
    task_templates = generate_task_templates()
    # 2. Load the tasks into the queue
    load_tasks_to_queue(task_templates)
    # 3. Start the monitor thread
    monitor_thread = TaskMonitorThread(check_interval=10)
    monitor_thread.start()
    # 4. Start the submit thread
    submit_thread = TaskSubmitThread(max_workers=3)
    submit_thread.start()
    # 5. Wait for the threads to finish (stop the submitter once the monitor
    #    reports that all tasks are done, so the program can exit)
    monitor_thread.join()
    submit_thread.stop()
    submit_thread.join()
    logger.info("All tasks processed; exiting")