bowei.fw 2025-03-22 14:57:12 +08:00
parent 88e99f887a
commit bd8df23a13
4 changed files with 69 additions and 18 deletions

View File

@@ -513,6 +513,10 @@ def tp_and_pp_group():
return grid().get_model_parallel_group()
def tp_and_pp_cpu_group():
return grid().ds_model_proc_group_gloo
def tp_and_pp_rank():
"""Used as the rank in the world group of vLLM."""
return grid().get_model_parallel_rank()
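
The new tp_and_pp_rank() helper exposes this worker's index inside the combined tensor+pipeline group; the reward interface below uses it to pick its share of dispatched micro-batches. A minimal standalone sketch of the idea, assuming a pipe-major [pipe, tensor] ordering (the real rank comes from the topology grid, not from this formula):

# Hedged sketch, not the repo's implementation: flatten (pipe, tensor)
# coordinates into a single TP x PP rank, assuming pipe-major ordering.
def tp_and_pp_rank_from_coords(pipe_rank: int, tp_rank: int, tp_world_size: int) -> int:
    # Example: pipe_rank=1, tp_rank=0, tp_world_size=2 -> rank 2 of 4.
    return pipe_rank * tp_world_size + tp_rank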

View File

@@ -21,7 +21,7 @@ def check_is_realhf_native_impl(_cls):
def check_is_realhf_native_model_interface(name):
# NOTE: we should not import interfaces here,
# such that we can avoid CUDA initialization.
-return name in ["ppo_actor", "ppo_critic", "sft", "reward", "fused-threading"]
+return name in ["ppo_actor", "ppo_critic", "sft", "rw-math-code", "fused-threading"]
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):

View File

@@ -202,7 +202,9 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
dispatched = {}
dispatched_indices = {}
for task_idx, task_name in enumerate(RL_TASKS):
-indices = (data.data["task_ids"] == task_idx).numpy().nonzero()[0].tolist()
+indices = (
+    (data.data["task_ids"] == task_idx).cpu().numpy().nonzero()[0].tolist()
+)
if len(indices) > 0:
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
dispatched_indices[task_name] = indices
@@ -221,6 +223,53 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
assert all(xs)
return SequenceSample.gather(xs)
def _dispatch_tp_and_pp(self, data: SequenceSample):
tp_pp_size = constants.tp_and_pp_world_size()
if tp_pp_size == 1:
return data, None
splitted, _, backward_indices = data.split(
mb_spec=MicroBatchSpec(n_mbs=tp_pp_size)
)
tp_pp_rank = constants.tp_and_pp_rank()
print("dispatched batch size", [s.bs for s in splitted], flush=True)
return splitted[tp_pp_rank], backward_indices
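
_dispatch_tp_and_pp splits the incoming batch into tp_and_pp_world_size() micro-batches via SequenceSample.split(mb_spec=MicroBatchSpec(n_mbs=...)) and keeps only the slice for this worker's tp_and_pp_rank(), remembering backward_indices so the original order can be restored after the gather. A self-contained sketch of that split/restore contract on plain Python lists (hedged; dispatch/restore are illustrative names, not the repo's SequenceSample API):

# Hedged sketch: split a batch into one chunk per TP x PP rank and record
# where each item lands once the chunks are concatenated back in rank order.
from typing import List, Tuple

def dispatch(batch: List[int], n_ranks: int) -> Tuple[List[List[int]], List[int]]:
    chunks = [batch[r::n_ranks] for r in range(n_ranks)]  # round-robin split
    gathered_order = [i for r in range(n_ranks) for i in range(r, len(batch), n_ranks)]
    backward_indices = [0] * len(batch)
    for gathered_pos, original_pos in enumerate(gathered_order):
        backward_indices[original_pos] = gathered_pos
    return chunks, backward_indices

def restore(chunks: List[List[int]], backward_indices: List[int]) -> List[int]:
    flat = [x for chunk in chunks for x in chunk]  # concatenation in rank order
    return [flat[g] for g in backward_indices]     # back to the original order

assert restore(*dispatch(list(range(7)), 3)) == list(range(7))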
def _gather_tp_and_pp(self, input_, data: SequenceSample, backward_indices):
tp_pp_size = constants.tp_and_pp_world_size()
if tp_pp_size == 1:
return data
local_rank = constants.grid().topo.get_rank(
data=constants.data_parallel_rank(),
model=0,
pipe=constants.pipe_parallel_world_size() - 1,
)
dst = constants.to_global_pg_rank(local_rank)
gather_list = None
if dist.get_rank() == dst:
gather_list = [None for _ in range(tp_pp_size)]
x = data.data["rewards"].cpu().numpy().tolist()
print(x, flush=True)
dist.gather_object(
x, gather_list, dst=dst, group=constants.tp_and_pp_cpu_group()
)
if dist.get_rank() != dst:
return None
gathered = np.array(gather_list).reshape(-1, self.group_size)
assert len(gathered) == len(backward_indices)
rewards = (
np.concatenate([gathered[i] for i in backward_indices]).flatten().tolist()
)
return SequenceSample(
keys=["rewards"],
trailing_shapes=dict(rewards=()),
dtypes=dict(rewards=torch.float32),
ids=input_.ids,
seqlens=dict(
rewards=[[1 for _ in range(self.group_size)] for _ in range(input_.bs)],
),
data=dict(rewards=torch.tensor(rewards, dtype=torch.float32)),
)
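
_gather_tp_and_pp collects every rank's local reward list onto one destination rank with torch.distributed.gather_object over the gloo CPU group (constants.tp_and_pp_cpu_group()), then rebuilds a rewards-only SequenceSample in the original order. The collective pattern in isolation (hedged sketch; gather_rewards and its arguments are illustrative, not repo helpers):

# Hedged sketch of the collective used above: gather picklable Python objects
# from all ranks of a CPU/gloo group onto a single destination rank.
import torch.distributed as dist

def gather_rewards(local_rewards, dst, group, n_ranks):
    # Only the destination rank allocates the receive buffer.
    gather_list = [None] * n_ranks if dist.get_rank() == dst else None
    dist.gather_object(local_rewards, gather_list, dst=dst, group=group)
    if dist.get_rank() != dst:
        return None  # non-destination ranks get nothing back
    # Flatten the per-rank lists into one list in rank order.
    return [r for part in gather_list for r in part]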
def calculate_task_reward(
self,
model: model_api.Model,
@@ -266,22 +315,17 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
# NOTE: a placeholder
dense_scores = torch.zeros_like(packed_input_ids, dtype=torch.float32)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
trailing_shapes=dict(rewards=(), dense_rewards=()),
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
keys=["rewards"],
trailing_shapes=dict(rewards=()),
dtypes=dict(rewards=torch.float32),
ids=data.ids,
seqlens=dict(
rewards=[
-torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
-for x in data.seqlens["packed_input_ids"]
+[1 for _ in range(len(x))] for x in data.seqlens["packed_input_ids"]
],
-dense_rewards=data.seqlens["packed_input_ids"],
),
-data=dict(rewards=scores, dense_rewards=dense_scores),
+data=dict(rewards=scores),
)
# record rewards for each piece of data
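
After this change the interface emits only sparse rewards: one scalar per generated sequence, so a sample holding G grouped responses contributes G reward entries of length 1, and the dense per-token dense_rewards key is dropped. Roughly (my reading of the diff, not repo documentation):

# Hedged illustration of the sparse-reward layout implied by the diff.
import torch

packed_seqlens = [[37, 41], [52, 29]]            # 2 samples, 2 sequences each
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])     # one scalar reward per sequence
seqlens_rewards = [[1] * len(x) for x in packed_seqlens]  # [[1, 1], [1, 1]]
assert sum(map(sum, seqlens_rewards)) == rewards.numel()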
@@ -317,6 +361,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
def log_rewards_to_file(
self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
):
tik = time.perf_counter()
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
@@ -367,6 +412,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
logger.info(f"[{task_type}] number of samples: {len(scores)}, {scores.shape}")
logger.info(f"[{task_type}] avg reward: {sum(scores) / len(scores)}")
logger.info(f"[{task_type}] log to file time: {time.perf_counter()- tik:.2f}s")
def inference(
self,
@@ -374,6 +420,8 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
input_ = data
data, backward_indices = self._dispatch_tp_and_pp(data)
task_data, dispatch_indices = self._dispatch_tasks(data)
assert self.rw_type == "sparse"
@@ -410,11 +458,9 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
return task_results
-if constants.is_dp_head():
-    task_results = asyncio.run(_run_tasks())
-    final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
-else:
-    final_result = None
+task_results = asyncio.run(_run_tasks())
+final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
+final_result = self._gather_tp_and_pp(input_, final_result, backward_indices)
model.inc_version()

View File

@@ -53,7 +53,7 @@ def math_code_dataset(request, save_path):
@pytest.mark.parametrize(
"dp,pp,mp",
[
-(1, 1, 1),
+(2, 2, 1),
# (2, 1, 2),
# (1, 2, 1),
# (1, 1, 2),
@@ -114,6 +114,7 @@ def test_ppo_symm(
use_cuda_graph=False,
),
),
group_size=2,
)
run_test_exp(exp_cfg)
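
The test now exercises a dp=2, pp=2, mp=1 grid with grouped generation, so it launches more workers than before, and group_size=2 means each prompt carries two sampled responses whose rewards flow through the new TP/PP gather. A quick sanity check of that arithmetic (assuming the symmetric test's world size is the product of the three parallel degrees):

# Assumption: the symmetric test launches dp * pp * mp workers.
dp, pp, mp = 2, 2, 1
world_size = dp * pp * mp
assert world_size == 4          # was 1 with the old (1, 1, 1) setting
group_size = 2                  # responses sampled per prompt in the test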