Master Worker

Overview

The worker architecture of AReaL consists of a single master worker coordinating multiple model workers.

An RL algorithm typically contains several model function calls (MFCs) that must be executed in a certain order. For example, in PPO:

  1. actor_gen generates responses given a batch of user prompts;
  2. ref_inf computes the log-probabilities of the tokens under the reference policy;
  3. rew_inf computes the rewards of the responses;
  4. actor_train updates the policy with the PPO learning objective.

Here, model function calls 2 and 3 depend on the output of 1, and model function call 4 depends on the outputs of 1, 2, and 3.
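The dependencies among the four PPO MFCs form a small directed acyclic graph. The sketch below is only an illustration of that ordering constraint using Python's standard `graphlib`; the `deps` dict and the `execution_stages` helper are hypothetical and are not part of AReaL.

```python
from graphlib import TopologicalSorter

# Each MFC maps to the set of MFCs whose outputs it consumes
# (taken from the PPO example above).
deps = {
    "actor_gen": set(),
    "ref_inf": {"actor_gen"},
    "rew_inf": {"actor_gen"},
    "actor_train": {"actor_gen", "ref_inf", "rew_inf"},
}


def execution_stages(deps):
    """Group MFCs into stages; MFCs within one stage do not depend on each other."""
    sorter = TopologicalSorter(deps)
    sorter.prepare()
    stages = []
    while sorter.is_active():
        # Everything whose dependencies are already done can run now.
        ready = sorted(sorter.get_ready())
        stages.append(ready)
        sorter.done(*ready)
    return stages


stages = execution_stages(deps)
print(stages)
```

In this picture, ref_inf and rew_inf land in the same stage, which reflects that neither needs the other's output.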

The MFCs are coordinated by a FunctionExecutor instance, which creates a ModelFunctionCall instance for each MFC. The actual computation is performed on the model workers via remote procedure calls (RPCs).

Buffer and MFC Execution Order

The master worker creates an AsyncIOSequenceBuffer, which is referenced by the FunctionExecutor and by each ModelFunctionCall instance. The buffer is responsible for managing (meta)data and deciding the execution order of the MFCs.

Each datapoint can be seen as a dict of tensors. For example, the keys may include packed_prompts and task_ids. Recall that some MFCs may rely on the output of another. For example, in PPO, the MFC ref_inf requires packed_input_ids, which is not present initially. Instead, packed_input_ids appears as one of the results of the MFC actor_gen.

The buffer keeps track of the available keys of each datapoint. Each ModelFunctionCall instance obtains its next batch via self.get_batch_for_rpc, which waits until enough datapoints have all the required keys. This means that an MFC does not start execution until all of its required keys are ready. After a model function call finishes executing, it calls self.amend_batch to update the corresponding datapoints with the new keys.
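The wait-then-amend pattern can be sketched with a plain `asyncio.Condition`. This is a toy model under stated assumptions, not AReaL's AsyncIOSequenceBuffer: the class `ToySequenceBuffer`, its methods `get_batch` and `amend`, and the key name `ref_logprobs` are all hypothetical. They only mirror the roles that get_batch_for_rpc and amend_batch play in the description above.

```python
import asyncio


class ToySequenceBuffer:
    """Toy stand-in that tracks which keys each datapoint currently has."""

    def __init__(self):
        self._keys = {}    # datapoint index -> set of available keys
        self._taken = {}   # MFC name -> indices that MFC already consumed
        self._cond = asyncio.Condition()

    def _ready(self, mfc, required):
        done = self._taken.setdefault(mfc, set())
        return [i for i, keys in sorted(self._keys.items())
                if required <= keys and i not in done]

    async def get_batch(self, mfc, required, n):
        """Block until n datapoints carry every required key, then claim them."""
        async with self._cond:
            await self._cond.wait_for(lambda: len(self._ready(mfc, required)) >= n)
            batch = self._ready(mfc, required)[:n]
            self._taken[mfc].update(batch)
            return batch

    async def amend(self, indices, new_keys):
        """Record newly produced keys and wake up any waiting MFCs."""
        async with self._cond:
            for i in indices:
                self._keys.setdefault(i, set()).update(new_keys)
            self._cond.notify_all()


async def run_mfc(buf, log, name, required, produced):
    batch = await buf.get_batch(name, required, n=2)
    log.append((name, batch))          # the RPC to model workers would go here
    await buf.amend(batch, produced)


async def main():
    buf = ToySequenceBuffer()
    log = []
    # Keys loaded from the dataset.
    await buf.amend([0, 1], {"packed_prompts", "task_ids"})
    # ref_inf is scheduled first but must wait for packed_input_ids.
    await asyncio.gather(
        run_mfc(buf, log, "ref_inf", {"packed_input_ids"}, {"ref_logprobs"}),
        run_mfc(buf, log, "actor_gen", {"packed_prompts"}, {"packed_input_ids"}),
    )
    return log


log = asyncio.run(main())
print(log)
```

Although ref_inf is launched first, it only proceeds once actor_gen has amended the datapoints with packed_input_ids, so the execution order falls out of key availability rather than an explicit schedule.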

While some keys are the results of MFCs, others are loaded from the dataset via FunctionExecutor.load_data. Also note that the buffer stores only metadata (data indices, keys, etc.) rather than the actual data, which reduces the cost of data transfer.
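To make the metadata-versus-data distinction concrete, here is a minimal sketch. It assumes, purely for illustration, that the payloads live in some worker-side store while the master-side buffer sees only small records; `SampleMeta` and `ToyWorkerStore` are hypothetical names, and plain lists stand in for tensors.

```python
from dataclasses import dataclass, field


@dataclass
class SampleMeta:
    """What a master-side buffer might hold for one datapoint (hypothetical)."""
    index: int
    keys: set = field(default_factory=set)


class ToyWorkerStore:
    """Hypothetical worker-side storage that holds the actual payloads."""

    def __init__(self):
        self._data = {}

    def store(self, index, **payload):
        # Keep the payload locally and hand back only its metadata.
        self._data.setdefault(index, {}).update(payload)
        return SampleMeta(index, set(self._data[index]))

    def fetch(self, meta, keys):
        # Resolve metadata back into the real values when they are needed.
        return {k: self._data[meta.index][k] for k in keys}


store = ToyWorkerStore()
meta = store.store(0, packed_prompts=[101, 2023, 102], task_ids=[3])
batch = store.fetch(meta, ["packed_prompts"])
print(meta.index, sorted(meta.keys))
```

The record passed around for scheduling contains an index and key names only, so its size does not grow with sequence length.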