PullRequest: 48 Misc changes for supporting the async worker in the future

Merge branch fw/async-worker of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/48

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


博惟 2025-03-20 10:18:26 +08:00
parent b8f1fc3ebf
commit 5759579aa8
7 changed files with 112 additions and 34 deletions

View File

@@ -684,7 +684,6 @@ class SequenceSample:
 class DataBatchMeta:
     dp_rank: int
     meta_sample: SequenceSample | None
-    is_final_batch: bool


 @dataclasses.dataclass

View File

@@ -60,3 +60,7 @@ def distributed_local_peer(experiment_name, trial_name, host_name, model_name):
 def distributed_master(experiment_name, trial_name, model_name):
     return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/distributed/master/{model_name}"
+
+
+def model_version(experiment_name, trial_name, model_name):
+    return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/model_version/{model_name}"

View File

@@ -6,17 +6,29 @@ import socket
 from contextlib import closing


-def find_free_port(low=1, high=65536):
-    """From stackoverflow Issue 1365265."""
+def find_free_port(low=1, high=65536, exclude_ports=None):
+    """Find a free port within the specified range, excluding certain ports."""
+    if exclude_ports is None:
+        exclude_ports = set()
     while True:
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
             s.bind(("", 0))
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             port = s.getsockname()[1]
-            if low <= port <= high:
+            if low <= port <= high and port not in exclude_ports:
                 return port


+def find_multiple_free_ports(count, low=1, high=65536):
+    """Find multiple mutually exclusive free ports."""
+    free_ports = set()
+    for _ in range(count):
+        port = find_free_port(low, high, exclude_ports=free_ports)
+        free_ports.add(port)
+    return list(free_ports)
+
+
 def gethostname():
     return socket.gethostname()
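
The new helper reserves several mutually distinct ports in one call, which a future async worker hosting multiple endpoints would need. A minimal usage sketch (actual port numbers vary per run):

    from realhf.base.network import find_free_port, find_multiple_free_ports

    ports = find_multiple_free_ports(3, low=10000, high=60000)
    assert len(set(ports)) == 3  # mutually exclusive by construction

    # find_free_port can now also skip ports already claimed elsewhere:
    extra = find_free_port(low=10000, high=60000, exclude_ports=set(ports))
    assert extra not in ports

Note that each probed socket is closed before its port number is returned, so another process may still grab the port in the meantime; exclude_ports only prevents collisions among ports handed out to the same caller.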

View File

@@ -718,16 +718,6 @@ class CommonExperimentConfig(Experiment):
                     )
                 else:
                     backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
-                if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
-                    "vllm",
-                    "sglang",
-                ]:
-                    raise ValueError(
-                        "vLLM or SGLang is not enabled for generation. "
-                        "This behavior has been deprecated. "
-                        "Please set model.vllm.hybrid_train=True "
-                        "or model.sglang.hybrid_train=True."
-                    )
                 if mapping[i, j]:
                     shard_idx = shard_counter[model_name]

View File

@@ -35,6 +35,8 @@ from realhf.base import (
     constants,
     gpu_utils,
     logging,
+    name_resolve,
+    names,
     network,
     recover,
     seeding,
@@ -374,7 +376,8 @@ class ModelWorker(worker_base.Worker):
                 constants.MODEL_SAVE_ROOT,
                 constants.experiment_name(),
                 constants.trial_name(),
-                f"dataset_indices_{self._dp_rank}.npy",
+                "dataset_indices",
+                f"{self._dp_rank}.npy",
             )
             if os.path.exists(dataset_indices_path):
                 indices = np.load(dataset_indices_path).tolist()
@@ -455,8 +458,8 @@ class ModelWorker(worker_base.Worker):
         self.__request_cache = {}
         self.__ack_cache = {}

-        self.__request_queue = queue.Queue(maxsize=8)
-        self.__reply_queue = queue.Queue(maxsize=8)
+        self.__request_queue = queue.Queue(maxsize=10240)
+        self.__reply_queue = queue.Queue(maxsize=10240)
         self.__request_sample_size = dict()

         self.__compute_input_queues = {
@@ -468,6 +471,13 @@ class ModelWorker(worker_base.Worker):
             for model_name in self.__models.keys()
         }

+        # By intention, must be smaller than -1.
+        self._last_param_realloc_step = -100
+        if self.__recover_run:
+            self._last_param_realloc_step = (
+                self.__recover_info.last_step_info.global_step
+            )
+
     def __handle_one_rpc_hook(self, hook: str, hook_data: Any):
         ret = None
@@ -579,8 +589,10 @@ class ModelWorker(worker_base.Worker):
                     constants.MODEL_SAVE_ROOT,
                     constants.experiment_name(),
                     constants.trial_name(),
-                    f"dataset_indices_{dp_rank}.npy",
+                    "dataset_indices",
+                    f"{dp_rank}.npy",
                 )
+                os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
                 if hasattr(self.__dataset, "filter") and os.path.exists(
                     eval_scores_path
                 ):
@@ -624,9 +636,6 @@ class ModelWorker(worker_base.Worker):
                 res = data_api.DataBatchMeta(
                     dp_rank=dp_rank,
                     meta_sample=meta_sample,
-                    is_final_batch=(
-                        self.__dataset_batch_counter == len(self.__dataloader) - 1
-                    ),
                 )
             elif request.handle_name == "spec":
                 # Raw dataset without filtering.
@@ -736,6 +745,31 @@ class ModelWorker(worker_base.Worker):
                 assert isinstance(res, dict), res
                 res.update({f"eval_{k}": v for k, v in ret.items()})

+        # update param realloc step after handling post hooks
+        if request.handle_name == "train_step":
+            self._last_param_realloc_step = max(self._last_param_realloc_step + 1, 1)
+            realloc_dir = os.path.join(
+                constants.PARAM_REALLOC_PATH,
+                constants.experiment_name(),
+                constants.trial_name(),
+                model_name.role,
+            )
+            save_meta = dict(
+                model_name=model_name,
+                save_backend=False,
+                save_dir=realloc_dir,
+            )
+            self.__save_model(save_meta)
+            name = names.model_version(
+                self.__experiment_name,
+                self.__trial_name,
+                model_name.role,
+            )
+            with constants.model_scope(model_name):
+                dist.barrier(group=constants.parallelism_group())
+                if constants.parallelism_rank() == 0:
+                    name_resolve.add_subentry(name, str(self._last_param_realloc_step))
+
         self.__reply_queue.put_nowait((request, res))

         sample_count = data.bs if isinstance(data, data_api.SequenceSample) else 1
         self.__request_sample_size[request.request_id] = sample_count
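
Taken together, each train_step now (1) bumps _last_param_realloc_step (initialized to -100 above so a recovered run can overwrite it, then clamped to start from 1), (2) saves the weights under PARAM_REALLOC_PATH, and (3) advertises the new version via name_resolve from parallelism rank 0. A hedged sketch of how a future async consumer could pick this up (get_subtree is assumed to be the read counterpart of add_subentry; load_weights is a hypothetical loader):

    import os

    from realhf.base import constants, name_resolve, names

    key = names.model_version(experiment_name, trial_name, "actor")
    versions = sorted(int(v) for v in name_resolve.get_subtree(key))
    realloc_dir = os.path.join(
        constants.PARAM_REALLOC_PATH, experiment_name, trial_name, "actor"
    )
    load_weights(realloc_dir, version=versions[-1])  # hypothetical loader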
@@ -1163,11 +1197,12 @@ class ModelWorker(worker_base.Worker):
     @cuda_tmark("post_response", CUDATimeMarkType.misc)
     def maybe_post_responses(self):
         ready_to_post = []
-        try:
-            request, res = self.__reply_queue.get_nowait()
-            ready_to_post.append((request, res))
-        except queue.Empty:
-            pass
+        while True:
+            try:
+                request, res = self.__reply_queue.get_nowait()
+                ready_to_post.append((request, res))
+            except queue.Empty:
+                break

         batch_size = sample_size = 0
         for request, res in ready_to_post:

View File

@@ -52,7 +52,7 @@ class Payload:
     syn_reply_id: uuid.UUID = None
     ack_reply_id: uuid.UUID = None

-    no_syn: bool = False
+    no_syn: bool = True

     send_time: float = None
@@ -164,7 +164,7 @@ class NameResolvingRequestClient:
         datas: List[Any] | None = None,
         payloads: List[Payload] | None = None,
         verbose: bool = True,
-        no_syn: bool = False,
+        no_syn: bool = True,
     ) -> List[uuid.UUID]:
        """Send requests of type `handle_type` to all `handlers` with
        corresponding `data`.
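
With the default flipped, requests skip the SYN/ACK round-trip implied by the syn_reply_id/ack_reply_id fields above unless the caller opts back in, presumably to cut per-request latency on the high-volume async path. A hypothetical call under the new default (handler names and handle type are illustrative, assuming stream is a NameResolvingRequestClient):

    ids = stream.request(
        handlers=["model_worker/0"],
        handle_type="fetch",
        datas=[None],
        no_syn=False,  # opt back in to the SYN/ACK handshake when needed
    )

Callers that relied on the old synchronous acknowledgement must now pass no_syn=False explicitly.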

View File

@@ -1,7 +1,7 @@
 # Copyright 2025 Ant Group Inc.
 # Copyright 2024 Wei Fu & Zhiyu Mei
 # Licensed under the Apache License, Version 2.0 (the "License").

+import asyncio
 import dataclasses
 import enum
 import os
@@ -547,6 +547,18 @@ class Worker:
         """Implemented by sub-classes."""
         raise NotImplementedError()

+    @property
+    def running(self):
+        return self.__running
+
+    @property
+    def exiting(self):
+        return self.__exiting
+
+    @property
+    def is_configured(self):
+        return self.__is_configured
+
     def configure(
         self,
         worker_info: system_api.WorkerInformation,
@@ -625,10 +637,11 @@ class Worker:
     def _exit_hook(self, exit_status: WorkerServerStatus):
         logger.warning(f"Exit with {exit_status}, hook not implemented, pass.")

-    def exit(self):
+    def exit(self, err: bool = False):
         self.logger.info("Exiting worker")
-        self._exit_hook(WorkerServerStatus.COMPLETED)
-        self.__set_status(WorkerServerStatus.COMPLETED)
+        status = WorkerServerStatus.ERROR if err else WorkerServerStatus.COMPLETED
+        self._exit_hook(status)
+        self.__set_status(status)
         self.__exiting = True

     def interrupt(self):
@@ -681,8 +694,7 @@ class Worker:
             logger.error(f"Worker encountered error {e}", exc_info=True)
             if isinstance(e, WorkerException):
                 raise e
-            self.__set_status(WorkerServerStatus.ERROR)
-            self._exit_hook(WorkerServerStatus.ERROR)
+            self.exit(err=True)
             raise e

     def __host_key(self, key: str):
@@ -694,6 +706,32 @@ class Worker:
         name_resolve.watch_names(keys, call_back=self.exit)

+
+class AsyncWorker(Worker):
+    async def _poll_async(self) -> PollResult:
+        raise NotImplementedError()
+
+    async def run_async(self):
+        self.logger.debug("Running worker now")
+        try:
+            while not self.exiting:
+                await asyncio.sleep(0.0)
+                self._server.handle_requests()
+                if not self.running:
+                    await asyncio.sleep(0.05)
+                    continue
+                if not self.is_configured:
+                    raise RuntimeError("Worker is not configured")
+                r = await self._poll_async()
+        except KeyboardInterrupt:
+            self.exit()
+        except Exception as e:
+            logger.error(f"Worker encountered error {e}", exc_info=True)
+            if isinstance(e, WorkerException):
+                raise e
+            self.exit(err=True)
+            raise e
+
+
 class MappingThread:
     """Wrapped of a mapping thread.