diff --git a/commit_task_jointcloud_isdataupload_0721.py b/commit_task_jointcloud_isdataupload_0721.py
new file mode 100644
index 0000000..27be555
--- /dev/null
+++ b/commit_task_jointcloud_isdataupload_0721.py
@@ -0,0 +1,846 @@
+import concurrent.futures
+import time
+import logging
+import threading
+from uuid import uuid4
+from typing import Dict, List, Optional
+import requests
+import os
+import json
+import yaml  # used for the YAML config file
+
+# -------------------------- Global configuration and constants --------------------------
+# Logging configuration
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler()]
+)
+logger = logging.getLogger(__name__)
+
+# Task status definitions
+TASK_STATUS = {
+    "SUBMITTED": "pending",                # initial state
+    "SUBMITTING": "submitting",            # submission in progress
+    "SUCCEED": "succeeded",                # confirmed successful by JointCloud
+    "FAILED": "failed",                    # confirmed failed by JointCloud
+    "RETRY_EXHAUSTED": "retry exhausted"   # exceeded the maximum failure count
+}
+
+# Global task dict (key=target_id, value=task details)
+task_map: Dict[str, Dict] = {}
+task_map_lock = threading.Lock()  # lock guarding task_map
+
+# Global dataset map (key=file_location, value=DatasetInfo instance)
+dataset_map: Dict[str, Dict] = {}
+dataset_lock = threading.Lock()  # lock guarding dataset_map
+
+# Sub-algorithm -> cluster/dataset mapping (loaded from YAML)
+ALGORITHM_MAPPING: Dict[int, Dict] = {}  # shape: {son_code_id: {"clusters": [], "file_location": ""}}
+
+# Task templates (loaded from YAML)
+TASK_TEMPLATES: List[Dict] = []
+
+# API configuration
+API_CONFIG = {
+    "login": {
+        "url": "http://119.45.255.234:30180/jcc-admin/admin/login",
+        "timeout": 10
+    },
+    "create_package": {
+        "url": "http://119.45.255.234:30180/jsm/jobSet/createPackage",
+        "timeout": 15
+    },
+    "upload_file": {
+        "url": "http://121.36.5.116:32010/object/upload",
+        "timeout": 3000
+    },
+    "notify_upload": {
+        "url": "http://119.45.255.234:30180/jsm/jobSet/notifyUploaded",
+        "timeout": 15
+    },
+    "bind_cluster": {
+        "url": "http://119.45.255.234:30180/jsm/jobSet/binding",
+        "timeout": 15
+    },
+    "query_binding": {
+        "url": "http://119.45.255.234:30180/jsm/jobSet/queryBinding",
+        "timeout": 15
+    },
+    "submit_task": {
+        "url": "http://119.45.255.234:30180/jsm/jobSet/submit",
+        "timeout": 100
+    },
+    "task_details": {
+        "url": "http://119.45.255.234:30180/pcm/v1/core/task/details",
+        "timeout": 15
+    }
+}
+
+# Cluster resource configuration (key=cluster ID, value=total/available resources)
+cluster_resources: Dict[str, Dict] = {
+    "1790300942428540928": {  # modelarts cluster
+        "total": {"CPU": 1024, "MEMORY": 2048, "NPU": 56},
+        "available": {"CPU": 1024, "MEMORY": 2048, "NPU": 56}
+    },
+    "1790300942428540929": {  # example of an additional cluster
+        "total": {"CPU": 512, "MEMORY": 1024, "NPU": 32},
+        "available": {"CPU": 512, "MEMORY": 1024, "NPU": 32}
+    }
+}
+cluster_lock = threading.Lock()  # lock guarding cluster resources
+
+
+# -------------------------- Data structure definitions --------------------------
+class DatasetInfo(dict):
+    """Dataset record"""
+
+    def __init__(self, file_location: str, name: str, size: float, **kwargs):
+        super().__init__()
+        self["file_location"] = file_location  # local path (primary key)
+        self["id"] = kwargs.get("id", str(uuid4()))  # unique dataset ID
+        self["name"] = name  # dataset name
+        self["size"] = size  # size in bytes
+        self["is_uploaded"] = kwargs.get("is_uploaded", False)  # uploaded yet?
+        self["upload_cluster"] = kwargs.get("upload_cluster", [])  # clusters uploaded to
+        self["upload_time"] = kwargs.get("upload_time")  # upload time
+        self["description"] = kwargs.get("description")  # description
+        self["dataset_target_id"] = kwargs.get("dataset_target_id")  # binding ID
+
+
+class TaskInfo(dict):
+    """Task record"""
+
+    def __init__(self, task_name: str, dataset_name: str, son_code_id: int, resource: Dict, **kwargs):
+        super().__init__()
+        self["target_id"] = kwargs.get("target_id", str(uuid4()))  # unique task ID
+        self["task_name"] = task_name  # task name
+        self["package_name"] = kwargs.get("package_name", f"{task_name.lower()}-pkg")  # package (folder) name
+        self["dataset_name"] = dataset_name  # associated dataset name
+        self["son_code_id"] = son_code_id  # sub-algorithm ID
+        self["resource"] = resource  # resource demand (CPU/MEMORY/NPU, ...)
+        self["status"] = TASK_STATUS["SUBMITTED"]  # initial status: pending
+        self["submit_time"] = kwargs.get("submit_time", time.strftime("%Y-%m-%d %H:%M:%S"))  # submit time
+        self["success_time"] = None  # success time (filled on success)
+        self["third_party_task_id"] = ""  # JointCloud task ID (filled after submission)
+        self["file_location"] = kwargs.get("file_location", "")  # local file path (from the mapping)
+        self["error_msg"] = ""  # error message
+        self["fail_count"] = 0  # failure count
+        self["max_fail_threshold"] = kwargs.get("max_fail_threshold", 3)  # maximum failures allowed
+        self["cluster_id"] = ""  # target cluster ID (from the mapping)
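+
+# Illustrative only (not executed): TaskInfo is a plain dict subclass, so a record
+# can be built directly; every value below is a hypothetical example, not data
+# from the real config.
+#
+#     demo_task = TaskInfo(
+#         task_name="demo-train-01",                   # hypothetical name
+#         dataset_name="demo.zip",                     # hypothetical dataset file
+#         son_code_id=101,                             # hypothetical sub-algorithm ID
+#         resource={"CPU": 4, "MEMORY": 8, "NPU": 1},  # hypothetical demand
+#     )
+#     task_map[demo_task["target_id"]] = demo_task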
self["package_name"] = kwargs.get("package_name", f"{task_name.lower()}-pkg") # 文件夹名称 + self["dataset_name"] = dataset_name # 关联数据集名称 + self["son_code_id"] = son_code_id # 子算法ID + self["resource"] = resource # 资源需求(CPU/MEMORY/NPU等) + self["status"] = TASK_STATUS["SUBMITTED"] # 初始状态:待提交 + self["submit_time"] = kwargs.get("submit_time", time.strftime("%Y-%m-%d %H:%M:%S")) # 提交时间 + self["success_time"] = None # 成功时间(成功时填充) + self["third_party_task_id"] = "" # 云际任务ID(提交后填充) + self["file_location"] = kwargs.get("file_location", "") # 本地文件路径(从映射获取) + self["error_msg"] = "" # 错误信息 + self["fail_count"] = 0 # 失败次数 + self["max_fail_threshold"] = kwargs.get("max_fail_threshold", 3) # 最大失败阈值 + self["cluster_id"] = "" # 提交的集群ID(从映射获取) + + +# -------------------------- YAML配置文件处理 -------------------------- +def load_yaml_config(yaml_path: str = "sonCode_cluster__mapping.yaml") -> None: + """加载YAML配置文件(包含算法映射和任务模板)""" + # 加载配置文件 + try: + with open(str, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + ALGORITHM_MAPPING = data.get("algorithm_mapping", {}) + TASK_TEMPLATES = data.get("task_templates", []) + logger.info( + f"成功加载配置文件 | 子算法映射: {len(ALGORITHM_MAPPING)} 条 | 任务模板: {len(TASK_TEMPLATES)} 条") + except Exception as e: + logger.error(f"加载配置文件失败: {str(e)}", exc_info=True) + +# -------------------------- 工具方法 -------------------------- +def generate_task_templates() -> List[Dict]: + """从全局变量返回任务模板(实际从YAML加载)""" + return TASK_TEMPLATES + + +def load_tasks_to_queue(templates: List[Dict]) -> None: + """将任务静态数据加载到任务队列(从映射获取file_location)""" + global task_map + with task_map_lock: + task_map.clear() + for idx, template in enumerate(templates): + try: + # 检查必填字段 + required_fields = ["task_name_template", "prefix", "dataset_name", "son_code_id", "resource"] + missing_fields = [f for f in required_fields if f not in template] + if missing_fields: + logger.warning(f"跳过无效任务模板(索引{idx}):缺少字段 {missing_fields}") + continue + + # 从映射获取file_location + son_code_id = template["son_code_id"] + mapping = ALGORITHM_MAPPING.get(son_code_id) + if not mapping: + logger.warning(f"子算法ID {son_code_id} 无映射配置,跳过任务模板(索引{idx})") + continue + file_location = mapping["file_location"] + + # 生成任务名称 + task_name = template["task_name_template"].format(prefix=template["prefix"]) + # 创建任务实例 + task = TaskInfo( + task_name=task_name, + dataset_name=template["dataset_name"], + son_code_id=son_code_id, + resource=template["resource"], + file_location=file_location # 从映射填充 + ) + task_map[task["target_id"]] = task + logger.info( + f"任务入队 | task_name: {task_name} | 子算法ID: {son_code_id} | 数据集路径: {file_location}") + except Exception as e: + logger.error(f"加载任务模板失败(索引{idx}):{str(e)}") + logger.info(f"任务队列加载完成 | 共 {len(task_map)} 个有效任务") + + +def select_cluster(task_resource: Dict, son_code_id: int) -> Optional[str]: + """根据任务资源需求和子算法ID选择合适的集群(优先从映射中选择)""" + # 1. 从映射获取该子算法支持的集群 + mapping = ALGORITHM_MAPPING.get(son_code_id) + if not mapping: + logger.warning(f"子算法ID {son_code_id} 无集群映射,尝试所有集群") + candidate_clusters = list(cluster_resources.keys()) + else: + candidate_clusters = mapping["clusters"] + + with cluster_lock: + # 2. 
+
+
+def load_tasks_to_queue(templates: List[Dict]) -> None:
+    """Load static task data into the task queue (file_location comes from the mapping)"""
+    global task_map
+    with task_map_lock:
+        task_map.clear()
+        for idx, template in enumerate(templates):
+            try:
+                # Check required fields
+                required_fields = ["task_name_template", "prefix", "dataset_name", "son_code_id", "resource"]
+                missing_fields = [f for f in required_fields if f not in template]
+                if missing_fields:
+                    logger.warning(f"Skipping invalid task template (index {idx}): missing fields {missing_fields}")
+                    continue
+
+                # Look up file_location in the mapping
+                son_code_id = template["son_code_id"]
+                mapping = ALGORITHM_MAPPING.get(son_code_id)
+                if not mapping:
+                    logger.warning(f"No mapping for sub-algorithm ID {son_code_id}; skipping task template (index {idx})")
+                    continue
+                file_location = mapping["file_location"]
+
+                # Build the task name
+                task_name = template["task_name_template"].format(prefix=template["prefix"])
+                # Create the task record
+                task = TaskInfo(
+                    task_name=task_name,
+                    dataset_name=template["dataset_name"],
+                    son_code_id=son_code_id,
+                    resource=template["resource"],
+                    file_location=file_location  # filled from the mapping
+                )
+                task_map[task["target_id"]] = task
+                logger.info(
+                    f"Task enqueued | task_name: {task_name} | sub-algorithm ID: {son_code_id} | dataset path: {file_location}")
+            except Exception as e:
+                logger.error(f"Failed to load task template (index {idx}): {str(e)}")
+        logger.info(f"Task queue loaded | {len(task_map)} valid tasks")
+
+
+def select_cluster(task_resource: Dict, son_code_id: int) -> Optional[str]:
+    """Pick a suitable cluster for the task's resource demand and sub-algorithm ID (mapped clusters first)"""
+    # 1. Get the clusters this sub-algorithm is mapped to
+    mapping = ALGORITHM_MAPPING.get(son_code_id)
+    if not mapping:
+        logger.warning(f"No cluster mapping for sub-algorithm ID {son_code_id}; trying all clusters")
+        candidate_clusters = list(cluster_resources.keys())
+    else:
+        candidate_clusters = mapping["clusters"]
+
+    with cluster_lock:
+        # 2. Check whether a candidate cluster satisfies the resource demand
+        for cluster_id in candidate_clusters:
+            if cluster_id not in cluster_resources:
+                logger.warning(f"Mapped cluster {cluster_id} is absent from the resource config; skipping")
+                continue
+            cluster = cluster_resources[cluster_id]
+            # Check resource availability
+            resource_match = True
+            for res_type, required in task_resource.items():
+                available = cluster["available"].get(res_type, 0)
+                if available < required:
+                    resource_match = False
+                    break
+            if resource_match:
+                # Reserve the resources
+                for res_type, required in task_resource.items():
+                    if res_type in cluster["available"]:
+                        cluster["available"][res_type] -= required
+                logger.info(f"Cluster {cluster_id} selected (mapped for sub-algorithm {son_code_id}) | available after update: {cluster['available']}")
+                return cluster_id
+
+        # 3. If no mapped cluster fits, try the remaining clusters
+        all_clusters = list(cluster_resources.keys())
+        for cluster_id in all_clusters:
+            if cluster_id in candidate_clusters:
+                continue  # already checked
+            cluster = cluster_resources[cluster_id]
+            resource_match = True
+            for res_type, required in task_resource.items():
+                available = cluster["available"].get(res_type, 0)
+                if available < required:
+                    resource_match = False
+                    break
+            if resource_match:
+                for res_type, required in task_resource.items():
+                    if res_type in cluster["available"]:
+                        cluster["available"][res_type] -= required
+                logger.info(f"Cluster {cluster_id} selected (unmapped) | available after update: {cluster['available']}")
+                return cluster_id
+
+    logger.warning(f"No cluster satisfies the resource demand | demand: {task_resource} | sub-algorithm: {son_code_id}")
+    return None
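+
+
+def release_cluster_resources(cluster_id: str, task_resource: Dict) -> None:
+    """Return resources previously reserved by select_cluster.
+
+    Factored-out counterpart to the reservation loop above; submit_single_task
+    below calls it on its two failure paths instead of repeating the loop inline.
+    """
+    with cluster_lock:
+        if cluster_id in cluster_resources:
+            for res_type, required in task_resource.items():
+                if res_type in cluster_resources[cluster_id]["available"]:
+                    cluster_resources[cluster_id]["available"][res_type] += required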
+
+
+# -------------------------- Dataset upload check and handling --------------------------
+def check_and_handle_dataset(file_location: str, dataset_name: str, cluster_id: str) -> Optional[str]:
+    """Check whether the dataset is already on the given cluster; upload it if not"""
+    global dataset_map
+
+    # Step 1: check whether the dataset has been uploaded
+    dataset = get_dataset_status(file_location, dataset_name, cluster_id)
+
+    # Step 2: upload it if it has not
+    if not dataset or not dataset["is_uploaded"] or cluster_id not in dataset["upload_cluster"]:
+        dataset = upload_dataset(file_location, dataset_name, cluster_id)
+        if not dataset:
+            return None  # upload failed
+
+    return dataset["dataset_target_id"]
+
+
+def get_dataset_status(file_location: str, dataset_name: str, cluster_id: str) -> Optional[Dict]:
+    """Check the dataset's status (whether it is already on the given cluster)"""
+    with dataset_lock:
+        if file_location in dataset_map:
+            dataset = dataset_map[file_location]
+            # Verify it has been uploaded to the target cluster
+            if dataset["is_uploaded"] and cluster_id in dataset["upload_cluster"]:
+                logger.info(
+                    f"Dataset {dataset_name} already on cluster {cluster_id} | target_id: {dataset['dataset_target_id']}")
+                return dataset
+    return None
+
+
+def upload_dataset(file_location: str, dataset_name: str, cluster_id: str) -> Optional[Dict]:
+    """Run the dataset upload flow"""
+    dataset_path = os.path.join(file_location, dataset_name)
+
+    # Check that the local file exists
+    if not os.path.exists(dataset_path):
+        logger.error(f"Local dataset file not found | path: {dataset_path}")
+        return None
+
+    # Get the file size in bytes
+    try:
+        file_size = os.path.getsize(dataset_path)
+    except OSError as e:
+        logger.error(f"Failed to get file size | path: {dataset_path} | error: {str(e)}")
+        return None
+
+    logger.info(f"Uploading dataset {dataset_name} to cluster {cluster_id} | path: {dataset_path}")
+
+    try:
+        # Get an auth token
+        token = get_token()
+        if not token:
+            logger.error("Failed to get token; cannot upload dataset")
+            return None
+
+        headers = {'Authorization': f'Bearer {token}'}
+        package_name = f"dataset-{dataset_name.split('.')[0]}-{uuid4().hex[:6]}"  # unique folder name
+
+        # 1. Create the dataset folder
+        create_payload = {
+            "userID": 5,
+            "name": package_name,
+            "dataType": "dataset",
+            "packageID": 0,
+            "uploadPriority": {"type": "specify", "clusters": [cluster_id]},
+            "bindingInfo": {
+                "clusterIDs": [cluster_id],
+                "name": package_name,
+                "category": "image",
+                "type": "dataset",
+                "imageID": "",
+                "chip": ["ASCEND"],
+                "packageID": 0
+            }
+        }
+        create_resp = requests.post(
+            API_CONFIG["create_package"]["url"],
+            json=create_payload,
+            headers=headers,
+            timeout=API_CONFIG["create_package"]["timeout"]
+        )
+        create_resp.raise_for_status()
+        create_result = create_resp.json()
+        if create_result.get("code") != "OK":
+            raise ValueError(f"Failed to create folder | response: {create_result}")
+        packageID = create_result["data"]["newPackage"]["packageID"]
+        logger.info(f"Dataset folder created | packageID: {packageID}")
+
+        # 2. Upload the file
+        info_data = {
+            "userID": 5,
+            "packageID": packageID,
+            "loadTo": [3],
+            "loadToPath": [f"/dataset/5/{package_name}/"]
+        }
+        with open(dataset_path, 'rb') as f:
+            form_data = {"info": (None, json.dumps(info_data)), "files": f}
+            upload_resp = requests.post(
+                API_CONFIG["upload_file"]["url"],
+                files=form_data,
+                headers=headers,
+                timeout=API_CONFIG["upload_file"]["timeout"]
+            )
+        upload_resp.raise_for_status()
+        upload_result = upload_resp.json()
+        if upload_result.get("code") != "OK":
+            raise ValueError(f"File upload failed | response: {upload_result}")
+        object_id = upload_result["data"]["uploadeds"][0]["objectID"]
+        logger.info(f"Dataset file uploaded | objectID: {object_id}")
+
+        # 3. Notify upload completion
+        notify_payload = {
+            "userID": 5,
+            "packageID": packageID,
+            "uploadParams": {
+                "dataType": "dataset",
+                "uploadInfo": {"type": "local", "localPath": dataset_name, "objectIDs": [object_id]}
+            }
+        }
+        notify_resp = requests.post(
+            API_CONFIG["notify_upload"]["url"],
+            json=notify_payload,
+            headers=headers,
+            timeout=API_CONFIG["notify_upload"]["timeout"]
+        )
+        notify_resp.raise_for_status()
+
+        # 4. Bind to the cluster
+        bind_payload = {
+            "userID": 5,
+            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": [cluster_id]}
+        }
+        bind_resp = requests.post(
+            API_CONFIG["bind_cluster"]["url"],
+            json=bind_payload,
+            headers=headers,
+            timeout=API_CONFIG["bind_cluster"]["timeout"]
+        )
+        bind_resp.raise_for_status()
+
+        # 5. Look up the binding ID (dataset_target_id)
+        query_resp = requests.post(
+            API_CONFIG["query_binding"]["url"],
+            json={"dataType": "dataset", "param": {"userID": 5, "bindingID": -1, "type": "private"}},
+            headers=headers,
+            timeout=API_CONFIG["query_binding"]["timeout"]
+        ).json()
+        if query_resp.get("code") != "OK":
+            raise ValueError(f"Failed to query binding ID | response: {query_resp}")
+
+        dataset_target_id = None
+        for item in query_resp["data"]["datas"]:
+            if item["info"]["name"] == package_name:
+                dataset_target_id = item["ID"]
+                break
+        if not dataset_target_id:
+            raise ValueError(f"No binding ID found for dataset {package_name}")
+
+        # Upload succeeded: create and store the dataset record
+        dataset = DatasetInfo(
+            file_location=file_location,
+            name=dataset_name,
+            size=file_size,
+            is_uploaded=True,
+            upload_cluster=[cluster_id],
+            upload_time=time.strftime("%Y-%m-%d %H:%M:%S"),
+            dataset_target_id=dataset_target_id
+        )
+
+        with dataset_lock:
+            dataset_map[file_location] = dataset
+        logger.info(f"Dataset {dataset_name} uploaded | target_id: {dataset_target_id}")
+
+        return dataset
+
+    except Exception as e:
+        logger.error(f"Dataset upload failed | name: {dataset_name} | error: {str(e)}", exc_info=True)
+        return None
+
+
+# -------------------------- API call functions --------------------------
+def get_token() -> Optional[str]:
+    """Get an auth token"""
+    login_payload = {"username": "admin", "password": "Nudt@123"}
+    try:
+        config = API_CONFIG["login"]
+        response = requests.post(config["url"], json=login_payload, timeout=config["timeout"])
+        response.raise_for_status()
+        result = response.json()
+        if result.get("code") == 200 and "data" in result and "token" in result["data"]:
+            logger.info("Token acquired")
+            return result["data"]["token"]
+        else:
+            logger.error(f"Failed to get token | response: {result}")
+            return None
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Login request failed: {str(e)}", exc_info=True)
+        return None
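+
+
+# get_token() logs in again on every call (upload_dataset, submit_single_task and
+# the status query each call it). The sketch below shows one way to cache the
+# token; the 600-second lifetime is an assumed value, not something the login
+# API advertises, and the helper is not wired into the flow above.
+_token_cache = {"token": None, "expires_at": 0.0}
+
+
+def get_token_cached(ttl_seconds: float = 600.0) -> Optional[str]:
+    """Illustrative sketch: reuse the last token until an assumed TTL expires"""
+    now = time.time()
+    if _token_cache["token"] and now < _token_cache["expires_at"]:
+        return _token_cache["token"]
+    token = get_token()
+    if token:
+        _token_cache["token"] = token
+        _token_cache["expires_at"] = now + ttl_seconds
+    return token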
+
+
+def submit_single_task(task: Dict) -> bool:
+    """Submit a single task to a cluster"""
+    token = get_token()
+    if not token:
+        with task_map_lock:
+            task["status"] = TASK_STATUS["FAILED"]
+            task["error_msg"] = "Failed to get token"
+        return False
+
+    # Get the selected cluster ID
+    cluster_id = task.get("cluster_id")
+    if not cluster_id:
+        with task_map_lock:
+            task["status"] = TASK_STATUS["FAILED"]
+            task["error_msg"] = "No cluster ID set"
+        logger.error(f"[{task['task_name']}] Submission failed: no cluster ID set")
+        return False
+
+    # Dataset check and upload
+    dataset_target_id = check_and_handle_dataset(
+        file_location=task["file_location"],
+        dataset_name=task["dataset_name"],
+        cluster_id=cluster_id
+    )
+    if not dataset_target_id:
+        # Dataset upload failed: release the reserved cluster resources
+        release_cluster_resources(cluster_id, task["resource"])
+        with task_map_lock:
+            task["status"] = TASK_STATUS["FAILED"]
+            task["error_msg"] = "Dataset upload failed"
+        logger.error(f"[{task['task_name']}] Submission failed: dataset handling failed")
+        return False
+
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': f'Bearer {token}'
+    }
+    task_name = task["task_name"]
+    package_name = task["package_name"]
+    son_code_id = task["son_code_id"]  # sub-algorithm ID
+
+    try:
+        # 1. Create the dataset folder
+        config = API_CONFIG["create_package"]
+        create_payload = {
+            "userID": 5,
+            "name": package_name,
+            "dataType": "dataset",
+            "packageID": 0,
+            "uploadPriority": {"type": "specify", "clusters": [cluster_id]},
+            "bindingInfo": {
+                "clusterIDs": [cluster_id],
+                "name": package_name,
+                "category": "image",
+                "type": "dataset",
+                "imageID": "",
+                "bias": [],
+                "region": [],
+                "chip": ["ASCEND"],
+                "selectedCluster": [],
+                "modelType": "",
+                "env": "",
+                "version": "",
+                "packageID": 0,
+                "points": 0
+            }
+        }
+        create_resp = requests.post(config["url"], json=create_payload, headers=headers, timeout=config["timeout"])
+        create_resp.raise_for_status()
+        create_result = create_resp.json()
+        if create_result.get("code") != "OK":
+            raise ValueError(f"Failed to create folder | API response: {create_result}")
+        packageID = create_result["data"]["newPackage"]["packageID"]
+        logger.info(f"[{task_name}] Step 1: folder created | packageID: {packageID} | cluster: {cluster_id}")
+
+        # 2. Notify upload completion
+        config = API_CONFIG["notify_upload"]
+        notify_payload = {
+            "userID": 5,
+            "packageID": packageID,
+            "uploadParams": {
+                "dataType": "dataset",
+                "uploadInfo": {"type": "local", "localPath": task["dataset_name"], "objectIDs": []}
+            }
+        }
+        notify_resp = requests.post(config["url"], json=notify_payload, headers=headers, timeout=config["timeout"])
+        notify_resp.raise_for_status()
+        logger.info(f"[{task_name}] Step 2: upload notification sent")
+
+        # 3. Bind the dataset to the cluster
+        config = API_CONFIG["bind_cluster"]
+        bind_payload = {
+            "userID": 5,
+            "info": {"type": "dataset", "packageID": packageID, "clusterIDs": [cluster_id]}
+        }
+        bind_resp = requests.post(config["url"], json=bind_payload, headers=headers, timeout=config["timeout"])
+        bind_resp.raise_for_status()
+        logger.info(f"[{task_name}] Step 3: dataset bound to cluster {cluster_id}")
+
+        # 4. Submit the training task
+        config = API_CONFIG["submit_task"]
+        task_res = task["resource"]
+
+        submit_payload = {
+            "userID": 5,
+            "jobSetInfo": {
+                "jobs": [
+                    {
+                        "localJobID": "1",
+                        "name": task_name,
+                        "description": "Training task submitted automatically",
+                        "type": "AI",
+                        "files": {
+                            "dataset": {"type": "Binding", "bindingID": dataset_target_id},
+                            "model": {"type": "Binding", "bindingID": ""},
+                            "image": {"type": "Image", "imageID": 11}
+                        },
+                        "jobResources": {
+                            "scheduleStrategy": "dataLocality",
+                            "clusters": [
+                                {
+                                    "clusterID": cluster_id,
+                                    "runtime": {"envs": {}, "params": {}},
+                                    "code": {"type": "Binding", "bindingID": son_code_id},
+                                    "resources": [
+                                        {"type": "CPU", "name": "ARM", "number": task_res["CPU"]},
+                                        {"type": "MEMORY", "name": "RAM", "number": task_res["MEMORY"]},
+                                        {"type": "MEMORY", "name": "VRAM", "number": 32},
+                                        {"type": "STORAGE", "name": "DISK", "number": 886},
+                                        {"type": "NPU", "name": "ASCEND910", "number": task_res.get("NPU", 0)}
+                                    ]
+                                }
+                            ]
+                        }
+                    }
+                ]
+            }
+        }
+
+        response = requests.post(
+            config["url"],
+            json=submit_payload,
+            headers=headers,
+            timeout=config["timeout"]
+        )
+        response.raise_for_status()
+        submit_resp = response.json()
+        if submit_resp.get("code") != "OK":
+            raise ValueError(f"Task submission failed | API response: {submit_resp}")
+
+        third_party_task_id = submit_resp.get('data', {}).get('jobSetID')
+        logger.info(f"[{task_name}] Step 4: task submitted to cluster {cluster_id} | JointCloud task ID: {third_party_task_id}")
+
+        # Mark the task as succeeded
+        with task_map_lock:
+            task["status"] = TASK_STATUS["SUCCEED"]
+            task["third_party_task_id"] = third_party_task_id
+        return True
+
+    except Exception as e:
+        error_msg = f"Submission failed: {str(e)}"
+        # Release the cluster resources when the task fails
+        release_cluster_resources(cluster_id, task["resource"])
+        logger.info(f"Task failed; released resources on cluster {cluster_id}: {task['resource']}")
+
+        with task_map_lock:
+            task["fail_count"] += 1
+            if task["fail_count"] >= task["max_fail_threshold"]:
+                task["status"] = TASK_STATUS["RETRY_EXHAUSTED"]
+            else:
+                task["status"] = TASK_STATUS["FAILED"]
+            task["error_msg"] = error_msg
+        logger.error(f"[{task_name}] {error_msg}", exc_info=True)
+        return False
+
+
+def query_third_party_task_status(third_party_task_id: str) -> Optional[Dict]:
+    """Query the task's status on the JointCloud platform"""
+    if not third_party_task_id:
+        logger.warning("Empty JointCloud task ID; cannot query status")
+        return None
+    try:
+        url = "http://119.45.255.234:30180/pcm/v1/core/task/list"
+        params = {
+            "pageNum": 1,
+            "pageSize": 10,
+            "type": 1
+        }
+        token = get_token()
+        if not token:
+            logger.error("Failed to get token; cannot query task status")
+            return None
+        headers = {"Authorization": f"Bearer {token}"}
+
+        response = requests.get(
+            url,
+            params=params,
+            headers=headers,
+            timeout=15
+        )
+        response.raise_for_status()
+        result = response.json()
+
+        if result.get("code") != 200 or "data" not in result or "list" not in result["data"]:
+            logger.error(f"Task status API returned an unexpected response | response: {result}")
+            return None
+
+        # NOTE: only the first page (pageSize=10) is scanned for the task
+        task_list = result["data"]["list"]
+        target_task = None
+        for task in task_list:
+            if task.get("id") == third_party_task_id:
+                target_task = task
+                break
+        if not target_task:
+            logger.warning(f"Task {third_party_task_id} not found")
+            return None
+
+        task_status = target_task.get("status")
+        end_time = target_task.get("endTime")
+        logger.info(f"Task {third_party_task_id} status: {task_status} | end time: {end_time or 'not finished'}")
+        return {
+            "status": task_status,
+            "end_time": end_time
+        }
+    except Exception as e:
+        logger.error(f"Failed to query task status | task ID: {third_party_task_id} | error: {str(e)}", exc_info=True)
+        return None
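+
+
+# Illustrative sketch (not called by the threads below): block until a JointCloud
+# task reaches a terminal state, reusing query_third_party_task_status above.
+# The terminal status strings mirror the ones the monitor thread handles; the
+# default interval and timeout are arbitrary example values.
+def wait_for_third_party_task(third_party_task_id: str, interval: int = 15, timeout: int = 3600) -> Optional[Dict]:
+    deadline = time.time() + timeout
+    while time.time() < deadline:
+        info = query_third_party_task_status(third_party_task_id)
+        if info and info.get("status") in ("Succeeded", "Failed", "Saved"):
+            return info
+        time.sleep(interval)
+    return None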
+
+
+# -------------------------- Thread 1: task monitor thread --------------------------
+class TaskMonitorThread(threading.Thread):
+    """Monitor thread: watches task status only"""
+
+    def __init__(self, check_interval: int = 10):
+        super().__init__(name="TaskMonitorThread")
+        self.check_interval = check_interval
+        self._stop_event = threading.Event()
+
+    def run(self) -> None:
+        logger.info(f"Monitor thread started | interval: {self.check_interval}s")
+        while not self._stop_event.is_set():
+            with task_map_lock:
+                tasks = list(task_map.values())
+
+            for task in tasks:
+                with task_map_lock:
+                    current_status = task["status"]
+
+                if current_status in [TASK_STATUS["SUBMITTED"], TASK_STATUS["RETRY_EXHAUSTED"]]:
+                    continue
+
+                if current_status == TASK_STATUS["SUBMITTING"] and task["third_party_task_id"]:
+                    third_party_info = query_third_party_task_status(task["third_party_task_id"])
+                    if third_party_info:
+                        with task_map_lock:
+                            if third_party_info["status"] == "Succeeded":
+                                task["status"] = TASK_STATUS["SUCCEED"]
+                                task["success_time"] = third_party_info["end_time"]
+                                logger.info(f"Task {task['task_name']} succeeded | time: {task['success_time']}")
+                            elif third_party_info["status"] in ["Failed", "Saved"]:
+                                task["status"] = TASK_STATUS["FAILED"]
+                                task["fail_count"] += 1
+                                task["error_msg"] = f"JointCloud task failed (attempt {task['fail_count']})"
+                                logger.warning(f"Task {task['task_name']} failed | count: {task['fail_count']}")
+
+            if self._check_all_completed():
+                logger.info("All tasks finished")
+                self.stop()
+
+            self._stop_event.wait(self.check_interval)
+        logger.info("Monitor thread stopped")
+
+    def _check_all_completed(self) -> bool:
+        """Check whether every task has finished"""
+        with task_map_lock:
+            for task in task_map.values():
+                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
+                    return False
+                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
+                    return False
+        return True
+
+    def stop(self) -> None:
+        self._stop_event.set()
+
+
+# -------------------------- Thread 2: task submit thread --------------------------
+class TaskSubmitThread(threading.Thread):
+    """Submit thread: decides from the status whether to submit"""
+
+    def __init__(self, max_workers: int = 3):
+        super().__init__(name="TaskSubmitThread")
+        self.max_workers = max_workers
+        self._stop_event = threading.Event()
+
+    def run(self) -> None:
+        logger.info(f"Submit thread started | workers: {self.max_workers}")
+        while not self._stop_event.is_set():
+            # Collect tasks awaiting submission
+            with task_map_lock:
+                pending_tasks = []
+                for task in task_map.values():
+                    status = task["status"]
+                    if status == TASK_STATUS["SUBMITTED"]:
+                        pending_tasks.append(task)
+                    elif status == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
+                        pending_tasks.append(task)
+                    elif status == TASK_STATUS["FAILED"]:
+                        logger.info(f"Task {task['task_name']} exceeded the failure limit; no more submissions")
+
+            # Submit concurrently
+            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                futures = {executor.submit(self.commit_task, task): task for task in pending_tasks}
+                for future in concurrent.futures.as_completed(futures):
+                    task = futures[future]
+                    try:
+                        future.result()
+                    except Exception as e:
+                        with task_map_lock:
+                            task["status"] = TASK_STATUS["FAILED"]
+                            task["fail_count"] += 1
+                            task["error_msg"] = f"Submission error: {str(e)}"
+                        logger.error(f"Task {task['task_name']} submission error | error: {str(e)}")
+
+            if self._check_all_completed():
+                logger.info("All tasks finished; submit thread exiting")
+                self.stop()
+                break
+
+            if not pending_tasks:
+                logger.info("No tasks to submit; waiting 5 seconds")
+                self._stop_event.wait(5)
+
+        logger.info("Submit thread stopped")
+
+    def commit_task(self, task: Dict) -> None:
+        """Submission entry point: picks a cluster via the mapping"""
+        son_code_id = task["son_code_id"]
+        cluster_id = select_cluster(task["resource"], son_code_id)
+        if not cluster_id:
+            with task_map_lock:
+                task["status"] = TASK_STATUS["FAILED"]
+                task["fail_count"] += 1
+                task["error_msg"] = "No cluster available"
+            logger.error(f"Task {task['task_name']} submission failed: no cluster available")
+            return
+
+        # Mark as submitting
+        with task_map_lock:
+            task["status"] = TASK_STATUS["SUBMITTING"]
+            task["cluster_id"] = cluster_id
+        logger.info(f"Task {task['task_name']} submitting to cluster {cluster_id} (sub-algorithm {son_code_id})")
+
+        # Run the submission
+        submit_success = submit_single_task(task)
+        if not submit_success:
+            logger.warning(f"Task {task['task_name']} submission failed; will retry (current count: {task['fail_count']})")
+
+    def _check_all_completed(self) -> bool:
+        """Check whether every task has finished"""
+        with task_map_lock:
+            for task in task_map.values():
+                if task["status"] in [TASK_STATUS["SUBMITTED"], TASK_STATUS["SUBMITTING"]]:
+                    return False
+                if task["status"] == TASK_STATUS["FAILED"] and task["fail_count"] < task["max_fail_threshold"]:
+                    return False
+        return True
+
+    def stop(self) -> None:
+        self._stop_event.set()
+
+
+# -------------------------- Main program --------------------------
+if __name__ == "__main__":
+    # Load the YAML config (algorithm mapping and task templates)
+    load_yaml_config()
+    # Build the task templates and load them into the queue
+    task_templates = generate_task_templates()
+    load_tasks_to_queue(task_templates)
+
+    # Start the threads
+    monitor_thread = TaskMonitorThread(check_interval=10)
+    monitor_thread.start()
+
+    submit_thread = TaskSubmitThread(max_workers=3)
+    submit_thread.start()
+
+    # Wait for the threads to finish
+    monitor_thread.join()
+    submit_thread.join()
+    logger.info("All tasks processed; exiting")
\ No newline at end of file