mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent 88e99f887a
commit bd8df23a13
@@ -513,6 +513,10 @@ def tp_and_pp_group():
     return grid().get_model_parallel_group()
 
 
+def tp_and_pp_cpu_group():
+    return grid().ds_model_proc_group_gloo
+
+
 def tp_and_pp_rank():
     """Used as the rank in the world group of vLLM."""
     return grid().get_model_parallel_rank()
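The gloo-backed group added here exists so that plain Python objects (e.g. reward lists) can be gathered across all tensor+pipeline-parallel ranks without a CUDA-capable backend. A minimal sketch of that pattern, assuming the helpers live under a constants module as in the hunks below (the import path and the payload are illustrative, not taken from the repo):

    import torch.distributed as dist

    from realhf.base import constants  # assumed import path for the helpers above


    def gather_objects_over_tp_pp(obj, dst):
        # dist.gather_object needs a CPU-capable backend, hence the gloo group.
        gather_list = None
        if dist.get_rank() == dst:
            gather_list = [None for _ in range(constants.tp_and_pp_world_size())]
        dist.gather_object(
            obj, gather_list, dst=dst, group=constants.tp_and_pp_cpu_group()
        )
        return gather_list  # populated only on the destination rank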
@@ -21,7 +21,7 @@ def check_is_realhf_native_impl(_cls):
 def check_is_realhf_native_model_interface(name):
     # NOTE: we should not import interfaces here,
     # such that we can avoid CUDA initialization.
-    return name in ["ppo_actor", "ppo_critic", "sft", "reward", "fused-threading"]
+    return name in ["ppo_actor", "ppo_critic", "sft", "rw-math-code", "fused-threading"]
 
 
 def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
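Assuming the old "reward" entry is simply dropped from this whitelist (as the replaced line suggests), configurations that still request it will no longer resolve as realhf-native. A quick sanity check of the rename:

    assert check_is_realhf_native_model_interface("rw-math-code")
    assert not check_is_realhf_native_model_interface("reward")  # old name, now rejected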
@@ -202,7 +202,9 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
         dispatched = {}
         dispatched_indices = {}
         for task_idx, task_name in enumerate(RL_TASKS):
-            indices = (data.data["task_ids"] == task_idx).numpy().nonzero()[0].tolist()
+            indices = (
+                (data.data["task_ids"] == task_idx).cpu().numpy().nonzero()[0].tolist()
+            )
             if len(indices) > 0:
                 dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
                 dispatched_indices[task_name] = indices
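The added .cpu() matters because task_ids can live on the GPU, and torch.Tensor.numpy() raises for CUDA tensors. A small standalone repro of the failure mode this hunk avoids (assumes a CUDA device is available):

    import torch

    task_ids = torch.tensor([0, 1, 0, 2], device="cuda")
    mask = task_ids == 0
    # mask.numpy()  # TypeError: can't convert cuda:0 device type tensor to numpy
    indices = mask.cpu().numpy().nonzero()[0].tolist()
    print(indices)  # [0, 2]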
@@ -221,6 +223,53 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
         assert all(xs)
         return SequenceSample.gather(xs)
 
+    def _dispatch_tp_and_pp(self, data: SequenceSample):
+        tp_pp_size = constants.tp_and_pp_world_size()
+        if tp_pp_size == 1:
+            return data, None
+        splitted, _, backward_indices = data.split(
+            mb_spec=MicroBatchSpec(n_mbs=tp_pp_size)
+        )
+        tp_pp_rank = constants.tp_and_pp_rank()
+        print("dispatched batch size", [s.bs for s in splitted], flush=True)
+        return splitted[tp_pp_rank], backward_indices
+
+    def _gather_tp_and_pp(self, input_, data: SequenceSample, backward_indices):
+        tp_pp_size = constants.tp_and_pp_world_size()
+        if tp_pp_size == 1:
+            return data
+        local_rank = constants.grid().topo.get_rank(
+            data=constants.data_parallel_rank(),
+            model=0,
+            pipe=constants.pipe_parallel_world_size() - 1,
+        )
+        dst = constants.to_global_pg_rank(local_rank)
+        gather_list = None
+        if dist.get_rank() == dst:
+            gather_list = [None for _ in range(tp_pp_size)]
+        x = data.data["rewards"].cpu().numpy().tolist()
+        print(x, flush=True)
+        dist.gather_object(
+            x, gather_list, dst=dst, group=constants.tp_and_pp_cpu_group()
+        )
+        if dist.get_rank() != dst:
+            return None
+        gathered = np.array(gather_list).reshape(-1, self.group_size)
+        assert len(gathered) == len(backward_indices)
+        rewards = (
+            np.concatenate([gathered[i] for i in backward_indices]).flatten().tolist()
+        )
+        return SequenceSample(
+            keys=["rewards"],
+            trailing_shapes=dict(rewards=()),
+            dtypes=dict(rewards=torch.float32),
+            ids=input_.ids,
+            seqlens=dict(
+                rewards=[[1 for _ in range(self.group_size)] for _ in range(input_.bs)],
+            ),
+            data=dict(rewards=torch.tensor(rewards, dtype=torch.float32)),
+        )
+
     def calculate_task_reward(
         self,
         model: model_api.Model,
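Taken together, _dispatch_tp_and_pp and _gather_tp_and_pp scatter micro-batches across the tensor+pipeline ranks and then rebuild the original sample order on a single destination rank. A self-contained sketch of the reordering step, with plain lists standing in for SequenceSample (the reading of backward_indices as "backward_indices[j] is the split-order index of the group that sits at position j of the original batch" is an assumption about data.split's contract):

    import numpy as np

    group_size = 2
    # Rewards gathered from two ranks, each of which scored two groups.
    gather_list = [
        [0.1, 0.2, 0.3, 0.4],  # rank 0: groups A, B
        [0.5, 0.6, 0.7, 0.8],  # rank 1: groups C, D
    ]
    backward_indices = [2, 0, 3, 1]

    gathered = np.array(gather_list).reshape(-1, group_size)  # one row per group
    restored = np.concatenate([gathered[i] for i in backward_indices]).tolist()
    print(restored)  # rows C, A, D, B -> [0.5, 0.6, 0.1, 0.2, 0.7, 0.8, 0.3, 0.4]

Presumably gather_object over the gloo group is used because the payload is a small Python list of floats rather than a device tensor.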
@@ -266,22 +315,17 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
 
         self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
 
-        # NOTE: a place holder
-        dense_scores = torch.zeros_like(packed_input_ids, dtype=torch.float32)
-
         res = SequenceSample(
-            keys=["rewards", "dense_rewards"],
-            trailing_shapes=dict(rewards=(), dense_rewards=()),
-            dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
+            keys=["rewards"],
+            trailing_shapes=dict(rewards=()),
+            dtypes=dict(rewards=torch.float32),
             ids=data.ids,
             seqlens=dict(
                 rewards=[
-                    torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
-                    for x in data.seqlens["packed_input_ids"]
+                    [1 for _ in range(len(x))] for x in data.seqlens["packed_input_ids"]
                 ],
-                dense_rewards=data.seqlens["packed_input_ids"],
             ),
-            data=dict(rewards=scores, dense_rewards=dense_scores),
+            data=dict(rewards=scores),
         )
 
         # record rewards for each piece of data
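With dense_rewards removed, the interface now reports only one scalar reward per sequence, so the rewards seqlens entry mirrors packed_input_ids with every length collapsed to 1. The transformation in isolation (the example lengths are made up):

    # Token counts per sequence, grouped per sample (values are illustrative).
    packed_seqlens = [[120, 87], [256]]

    reward_seqlens = [[1 for _ in range(len(x))] for x in packed_seqlens]
    print(reward_seqlens)  # [[1, 1], [1]] -> one scalar reward per sequence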
@@ -317,6 +361,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
     def log_rewards_to_file(
         self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
     ):
+        tik = time.perf_counter()
         gen_file_path = os.path.join(
             constants.LOG_ROOT,
             constants.experiment_name(),
@@ -367,6 +412,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
 
         logger.info(f"[{task_type}] number of samples: {len(scores)}, {scores.shape}")
         logger.info(f"[{task_type}] avg reward: {sum(scores) / len(scores)}")
+        logger.info(f"[{task_type}] log to file time: {time.perf_counter()- tik:.2f}s")
 
     def inference(
         self,
@@ -374,6 +420,8 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
         data: SequenceSample,
         mb_spec: MicroBatchSpec,
     ) -> SequenceSample | None:
+        input_ = data
+        data, backward_indices = self._dispatch_tp_and_pp(data)
         task_data, dispatch_indices = self._dispatch_tasks(data)
 
         assert self.rw_type == "sparse"
@@ -410,11 +458,9 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
 
             return task_results
 
-        if constants.is_dp_head():
-            task_results = asyncio.run(_run_tasks())
-            final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
-        else:
-            final_result = None
+        task_results = asyncio.run(_run_tasks())
+        final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
+        final_result = self._gather_tp_and_pp(input_, final_result, backward_indices)
 
         model.inc_version()
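Read together with the earlier hunks, inference now runs on every tensor/pipeline rank rather than only on data-parallel heads, presumably because _gather_tp_and_pp issues a collective (dist.gather_object) that every rank in the group must enter. A condensed sketch of the resulting flow, not the literal method body (run_all_tasks is a hypothetical stand-in for the async closure defined inside the real method):

    import asyncio


    def reward_inference_sketch(iface, data):
        # Condensed view of MultiTaskRewardInterface.inference after this commit.
        input_ = data
        # 1. Shard the batch across all tensor + pipeline-parallel ranks.
        data, backward_indices = iface._dispatch_tp_and_pp(data)
        # 2. Route each sample to its task-specific reward function.
        task_data, dispatch_indices = iface._dispatch_tasks(data)
        # 3. Score every task's samples on this rank (hypothetical helper name).
        task_results = asyncio.run(run_all_tasks(iface, task_data))
        # 4. Merge per-task results back into one local batch...
        final_result = iface._gather_tasks(task_results, dispatch_indices, data.bs)
        # 5. ...then gather across ranks and restore the original order.
        return iface._gather_tp_and_pp(input_, final_result, backward_indices)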
@@ -53,7 +53,7 @@ def math_code_dataset(request, save_path):
 @pytest.mark.parametrize(
     "dp,pp,mp",
     [
         (1, 1, 1),
+        (2, 2, 1),
         # (2, 1, 2),
         # (1, 2, 1),
         # (1, 1, 2),
@@ -114,6 +114,7 @@ def test_ppo_symm(
                 use_cuda_graph=False,
             ),
         ),
+        group_size=2,
     )
 
     run_test_exp(exp_cfg)
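The new group_size=2 in the test config matches the reshape(-1, self.group_size) in _gather_tp_and_pp: the gathered reward list must factor into groups of group_size responses per prompt. In miniature:

    import numpy as np

    group_size = 2
    flat_rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0]  # 3 prompts x 2 sampled responses
    per_prompt = np.array(flat_rewards).reshape(-1, group_size)
    print(per_prompt.shape)  # (3, 2): one row of rewards per prompt group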