fix single node bug (#185)

This commit is contained in:
nuzant 2025-07-18 10:54:49 +08:00 committed by GitHub
parent 0d45f43285
commit 71c47c5f17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 6 deletions

View File

@ -175,9 +175,10 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
for k in all_available_resources for k in all_available_resources
if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k) if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k)
] ]
n_gpus_per_node = int(all_available_resources["GPU"] // len(all_available_nodes)) n_nodes = len(all_available_nodes)
n_gpus_per_node = int(all_available_resources["GPU"] // n_nodes)
assert ( assert (
all_available_resources["GPU"] % len(all_available_nodes) == 0 all_available_resources["GPU"] % n_nodes == 0
), "AReaL assumes all nodes has the same number of GPUs." ), "AReaL assumes all nodes has the same number of GPUs."
for worker_type in WORKER_TYPES: for worker_type in WORKER_TYPES:
@ -201,8 +202,8 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
) )
workers = [] workers = []
if sch.scheduling.gpu > 0: if sch.scheduling.gpu > 0 and n_nodes > 1:
# For GPU workers, schedule them in granularity of nodes. # When # nodes > 1, for GPU workers, schedule them in granularity of nodes.
assert ( assert (
n_gpus_per_node % sch.scheduling.gpu == 0 n_gpus_per_node % sch.scheduling.gpu == 0
), f"Each node should be allocated with identical numbers of {worker_type}." ), f"Each node should be allocated with identical numbers of {worker_type}."
@ -256,8 +257,10 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
) )
workers.append(worker) workers.append(worker)
else: else:
# For CPU workers, schedule them with SPREAD strategy # Schedule them with SPREAD strategy when
# to save as much resource as poosible on nodes for GPU workers. # 1. CPU workers when n_nodes > 1,
# to save as much resource as possible on nodes for GPU workers.
# 2. all workers when n_nodes = 1
for _idx in range(sch.count): for _idx in range(sch.count):
worker = RayWorker.options( worker = RayWorker.options(
name=f"{worker_type}/{_idx}", name=f"{worker_type}/{_idx}",