mirror of https://github.com/inclusionAI/AReaL
PullRequest: 48 Misc changes for supporting the async worker in the future

Merge branch fw/async-worker of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/48

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* .
* .
* .

Commit: 5759579aa8
Parent: b8f1fc3ebf
@@ -684,7 +684,6 @@ class SequenceSample:
 class DataBatchMeta:
     dp_rank: int
     meta_sample: SequenceSample | None
-    is_final_batch: bool
 
 
 @dataclasses.dataclass
@@ -60,3 +60,7 @@ def distributed_local_peer(experiment_name, trial_name, host_name, model_name):
 
 def distributed_master(experiment_name, trial_name, model_name):
     return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/distributed/master/{model_name}"
+
+
+def model_version(experiment_name, trial_name, model_name):
+    return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/model_version/{model_name}"
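
Note: a quick illustration of the key the new model_version() helper produces. The
experiment/trial/role values below are hypothetical, and USER_NAMESPACE is whatever
prefix names.py already defines; the model worker publishes version numbers under
this key via name_resolve.add_subentry (see the model_worker hunk further down).

    # Hypothetical values, for illustration only.
    key = model_version("my-exp", "trial0", "actor")
    # key == f"{USER_NAMESPACE}/my-exp/trial0/model_version/actor"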
@@ -6,17 +6,29 @@ import socket
 from contextlib import closing
 
 
-def find_free_port(low=1, high=65536):
-    """From stackoverflow Issue 1365265."""
+def find_free_port(low=1, high=65536, exclude_ports=None):
+    """Find a free port within the specified range, excluding certain ports."""
+    if exclude_ports is None:
+        exclude_ports = set()
+
     while True:
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
             s.bind(("", 0))
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             port = s.getsockname()[1]
-            if low <= port <= high:
+            if low <= port <= high and port not in exclude_ports:
                 return port
 
 
+def find_multiple_free_ports(count, low=1, high=65536):
+    """Find multiple mutually exclusive free ports."""
+    free_ports = set()
+    for _ in range(count):
+        port = find_free_port(low, high, exclude_ports=free_ports)
+        free_ports.add(port)
+    return list(free_ports)
+
+
 def gethostname():
     return socket.gethostname()
 
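
Note: a minimal usage sketch of the two helpers above (the module is imported as
realhf.base.network elsewhere in this commit):

    from realhf.base.network import find_free_port, find_multiple_free_ports

    # The returned ports are mutually distinct: each call passes the ports
    # found so far through exclude_ports.
    ports = find_multiple_free_ports(3, low=1024, high=65535)
    assert len(set(ports)) == 3

    # A single additional port that avoids the ones above.
    extra = find_free_port(low=1024, high=65535, exclude_ports=set(ports))
    assert extra not in ports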
@@ -718,16 +718,6 @@ class CommonExperimentConfig(Experiment):
                 )
             else:
                 backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
-            if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
-                "vllm",
-                "sglang",
-            ]:
-                raise ValueError(
-                    "vLLM or SGLang is not enabled for generation. "
-                    "This behavior has been deprecated. "
-                    "Please set model.vllm.hybrid_train=True "
-                    "or model.sglang.hybrid_train=True."
-                )
 
             if mapping[i, j]:
                 shard_idx = shard_counter[model_name]
@@ -35,6 +35,8 @@ from realhf.base import (
     constants,
     gpu_utils,
     logging,
+    name_resolve,
+    names,
     network,
     recover,
     seeding,
@@ -374,7 +376,8 @@ class ModelWorker(worker_base.Worker):
                 constants.MODEL_SAVE_ROOT,
                 constants.experiment_name(),
                 constants.trial_name(),
-                f"dataset_indices_{self._dp_rank}.npy",
+                "dataset_indices",
+                f"{self._dp_rank}.npy",
             )
             if os.path.exists(dataset_indices_path):
                 indices = np.load(dataset_indices_path).tolist()
@@ -455,8 +458,8 @@ class ModelWorker(worker_base.Worker):
         self.__request_cache = {}
         self.__ack_cache = {}
 
-        self.__request_queue = queue.Queue(maxsize=8)
-        self.__reply_queue = queue.Queue(maxsize=8)
+        self.__request_queue = queue.Queue(maxsize=10240)
+        self.__reply_queue = queue.Queue(maxsize=10240)
         self.__request_sample_size = dict()
 
         self.__compute_input_queues = {
@@ -468,6 +471,13 @@ class ModelWorker(worker_base.Worker):
             for model_name in self.__models.keys()
         }
 
+        # By intention, must be smaller than -1.
+        self._last_param_realloc_step = -100
+        if self.__recover_run:
+            self._last_param_realloc_step = (
+                self.__recover_info.last_step_info.global_step
+            )
+
     def __handle_one_rpc_hook(self, hook: str, hook_data: Any):
         ret = None
 
@@ -579,8 +589,10 @@ class ModelWorker(worker_base.Worker):
                 constants.MODEL_SAVE_ROOT,
                 constants.experiment_name(),
                 constants.trial_name(),
-                f"dataset_indices_{dp_rank}.npy",
+                "dataset_indices",
+                f"{dp_rank}.npy",
             )
+            os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
             if hasattr(self.__dataset, "filter") and os.path.exists(
                 eval_scores_path
             ):
@@ -624,9 +636,6 @@ class ModelWorker(worker_base.Worker):
             res = data_api.DataBatchMeta(
                 dp_rank=dp_rank,
                 meta_sample=meta_sample,
-                is_final_batch=(
-                    self.__dataset_batch_counter == len(self.__dataloader) - 1
-                ),
             )
         elif request.handle_name == "spec":
             # Raw dataset without filtering.
@@ -736,6 +745,31 @@ class ModelWorker(worker_base.Worker):
                 assert isinstance(res, dict), res
                 res.update({f"eval_{k}": v for k, v in ret.items()})
 
+        # update param realloc step after handling post hooks
+        if request.handle_name == "train_step":
+            self._last_param_realloc_step = max(self._last_param_realloc_step + 1, 1)
+            realloc_dir = os.path.join(
+                constants.PARAM_REALLOC_PATH,
+                constants.experiment_name(),
+                constants.trial_name(),
+                model_name.role,
+            )
+            save_meta = dict(
+                model_name=model_name,
+                save_backend=False,
+                save_dir=realloc_dir,
+            )
+            self.__save_model(save_meta)
+            name = names.model_version(
+                self.__experiment_name,
+                self.__trial_name,
+                model_name.role,
+            )
+            with constants.model_scope(model_name):
+                dist.barrier(group=constants.parallelism_group())
+                if constants.parallelism_rank() == 0:
+                    name_resolve.add_subentry(name, str(self._last_param_realloc_step))
+
         self.__reply_queue.put_nowait((request, res))
         sample_count = data.bs if isinstance(data, data_api.SequenceSample) else 1
         self.__request_sample_size[request.request_id] = sample_count
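
Note: the max(self._last_param_realloc_step + 1, 1) clamp interacts with the -100
sentinel initialized earlier; a pure-Python illustration of the resulting sequence:

    # The counter starts at -100 ("by intention, must be smaller than -1"), so
    # the first train_step clamps it to 1 and later steps count up from there.
    step = -100
    versions = []
    for _ in range(3):
        step = max(step + 1, 1)
        versions.append(step)
    assert versions == [1, 2, 3]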
@@ -1163,11 +1197,12 @@ class ModelWorker(worker_base.Worker):
     @cuda_tmark("post_response", CUDATimeMarkType.misc)
     def maybe_post_responses(self):
         ready_to_post = []
-        try:
-            request, res = self.__reply_queue.get_nowait()
-            ready_to_post.append((request, res))
-        except queue.Empty:
-            pass
+        while True:
+            try:
+                request, res = self.__reply_queue.get_nowait()
+                ready_to_post.append((request, res))
+            except queue.Empty:
+                break
 
         batch_size = sample_size = 0
         for request, res in ready_to_post:
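
Note: maybe_post_responses now drains every reply currently queued instead of at
most one per call. The pattern in isolation, as a self-contained sketch:

    import queue

    def drain(q: queue.Queue) -> list:
        # Non-blocking drain: keep popping until the queue reports Empty.
        items = []
        while True:
            try:
                items.append(q.get_nowait())
            except queue.Empty:
                return items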
@@ -52,7 +52,7 @@ class Payload:
     syn_reply_id: uuid.UUID = None
     ack_reply_id: uuid.UUID = None
 
-    no_syn: bool = False
+    no_syn: bool = True
 
     send_time: float = None
 
@@ -164,7 +164,7 @@ class NameResolvingRequestClient:
         datas: List[Any] | None = None,
         payloads: List[Payload] | None = None,
         verbose: bool = True,
-        no_syn: bool = False,
+        no_syn: bool = True,
     ) -> List[uuid.UUID]:
         """Send requests of type `handle_type` to all `handlers` with
         corresponding `data`.
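
Note: with both defaults flipped to True, requests now skip the SYN/ACK handshake
unless a caller explicitly passes no_syn=False. That reading is inferred from the
syn_reply_id/ack_reply_id fields on Payload, not stated in the commit message.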
@@ -1,7 +1,7 @@
 # Copyright 2025 Ant Group Inc.
 # Copyright 2024 Wei Fu & Zhiyu Mei
 # Licensed under the Apache License, Version 2.0 (the "License").
 
+import asyncio
 import dataclasses
 import enum
-import os
@@ -547,6 +547,18 @@ class Worker:
         """Implemented by sub-classes."""
         raise NotImplementedError()
 
+    @property
+    def running(self):
+        return self.__running
+
+    @property
+    def exiting(self):
+        return self.__exiting
+
+    @property
+    def is_configured(self):
+        return self.__is_configured
+
     def configure(
         self,
         worker_info: system_api.WorkerInformation,
@@ -625,10 +637,11 @@ class Worker:
     def _exit_hook(self, exit_status: WorkerServerStatus):
         logger.warning(f"Exit with {exit_status}, hook not implemented, pass.")
 
-    def exit(self):
+    def exit(self, err: bool = False):
         self.logger.info("Exiting worker")
-        self._exit_hook(WorkerServerStatus.COMPLETED)
-        self.__set_status(WorkerServerStatus.COMPLETED)
+        status = WorkerServerStatus.ERROR if err else WorkerServerStatus.COMPLETED
+        self._exit_hook(status)
+        self.__set_status(status)
         self.__exiting = True
 
     def interrupt(self):
@@ -681,8 +694,7 @@ class Worker:
             logger.error(f"Worker encountered error {e}", exc_info=True)
             if isinstance(e, WorkerException):
                 raise e
-            self.__set_status(WorkerServerStatus.ERROR)
-            self._exit_hook(WorkerServerStatus.ERROR)
+            self.exit(err=True)
             raise e
 
     def __host_key(self, key: str):
@@ -694,6 +706,32 @@ class Worker:
         name_resolve.watch_names(keys, call_back=self.exit)
 
 
+class AsyncWorker(Worker):
+    async def _poll_async(self) -> PollResult:
+        raise NotImplementedError()
+
+    async def run_async(self):
+        self.logger.debug("Running worker now")
+        try:
+            while not self.exiting:
+                await asyncio.sleep(0.0)
+                self._server.handle_requests()
+                if not self.running:
+                    await asyncio.sleep(0.05)
+                    continue
+                if not self.is_configured:
+                    raise RuntimeError("Worker is not configured")
+                r = await self._poll_async()
+        except KeyboardInterrupt:
+            self.exit()
+        except Exception as e:
+            logger.error(f"Worker encountered error {e}", exc_info=True)
+            if isinstance(e, WorkerException):
+                raise e
+            self.exit(err=True)
+            raise e
+
+
 class MappingThread:
     """Wrapped of a mapping thread.
 
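
Note: a minimal sketch of how a concrete worker might plug into the new AsyncWorker
base class. EchoAsyncWorker is hypothetical, and the PollResult(sample_count=...,
batch_count=...) constructor is assumed from how worker_base pollers report work:

    import asyncio

    class EchoAsyncWorker(AsyncWorker):
        async def _poll_async(self) -> PollResult:
            # One unit of cooperative work per scheduling round; run_async()
            # yields to the event loop between rounds via asyncio.sleep(0.0).
            await asyncio.sleep(0.01)
            return PollResult(sample_count=0, batch_count=0)

    # Driven by an event loop once the worker server has configured it:
    #     asyncio.run(worker.run_async())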