mirror of https://github.com/inclusionAI/AReaL
parent 04b26f42bb
commit 0291191716
@@ -323,7 +323,7 @@ def prepare_batch(
     workflow: "RolloutWorkflow",
 ):
     if not hasattr(self, "data_generator"):
-        self.data_generator = iter(dataloader)
+        self.data_generator = itertools.cycle(dataloader)
     assert dataloader.batch_size is not None
     while True:
         # Submit at least two batches to allow maximum overlap
@@ -332,11 +332,7 @@ def prepare_batch(
             and self.input_queue.qsize() + dataloader.batch_size
             < self.input_queue.maxsize
         ):
-            try:
-                data = next(self.data_generator)
-            except StopIteration:
-                self.data_generator = iter(dataloader)
-                data = next(self.data_generator)
+            data = next(self.data_generator)
             for item in data:
                 # submit data into input_queue
                 self.submit(item, workflow=workflow)
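The effect of this change is easiest to see in isolation. Below is a minimal, self-contained sketch (a plain list stands in for the DataLoader; none of this is AReaL code) contrasting the old restart-on-StopIteration pattern with the `itertools.cycle` pattern adopted here:

```python
import itertools

# Toy stand-in for a DataLoader: any finite iterable of batches works here.
dataloader = [["sample-0", "sample-1"], ["sample-2", "sample-3"]]

# Old pattern: next() raises StopIteration at the end of an epoch,
# so the iterator has to be rebuilt by hand.
gen = iter(dataloader)
try:
    data = next(gen)
except StopIteration:
    gen = iter(dataloader)
    data = next(gen)

# New pattern: cycle() repeats the iterable indefinitely, so next() never
# raises and the try/except wrapper disappears.
gen = itertools.cycle(dataloader)
for _ in range(5):   # more draws than the dataloader has batches
    data = next(gen)
print(data)          # ['sample-0', 'sample-1']
```

One behavioral note: `itertools.cycle` caches the items it has already seen and replays them, so a shuffling DataLoader repeats the first pass's batch order on later passes instead of reshuffling.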
@@ -364,18 +360,13 @@ rollout = RemoteSGLangEngine(config.inf_engine)
 rollout.initialize()
 eval_rollout = ...
 
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     # rollout batched training data for current step
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 ```
 
 If you want to use rollout workflows with custom reward functions or agentic tool
@@ -474,17 +465,12 @@ Now a complete GRPO training step in AReaL-lite is done! The core logic of our e
 training script can be summarized as:
 
 ```python
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
 logp = actor.compute_logp(batch)
 batch["prox_logp"] = logp
@@ -79,7 +79,7 @@ and converting it into an `LLMRequest` object for the inference engine:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... __init__ method above ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # Initialize result containers
         seq, logprobs, loss_mask, versions = [], [], [], []
         messages = data["messages"]
@@ -124,7 +124,7 @@ apply a discount, add feedback to the conversation, and let the model try again:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... initialization code ...
         while reward == 0 and t < self.max_turns:
             # Add feedback if the previous answer was incorrect
@@ -195,7 +195,7 @@ Finally, let's complete the implementation by collecting trajectories in the
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... episode logic above ...
 
         while reward == 0 and t < self.max_turns:
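The annotation added across these hunks makes the expected return value explicit: each episode is packed into a single `TensorDict`. A minimal sketch of what such a packing step could look like, assuming the `tensordict` package and using illustrative key names rather than the exact schema AReaL-lite consumes:

```python
import torch
from tensordict import TensorDict


def pack_episode(seq, logprobs, loss_mask, versions, reward) -> TensorDict:
    """Bundle the per-token lists collected during an episode into one TensorDict.

    The argument names mirror the containers initialized in the workflow
    (`seq`, `logprobs`, `loss_mask`, `versions`); the dict keys are illustrative.
    """
    return TensorDict(
        {
            "input_ids": torch.tensor(seq, dtype=torch.long).unsqueeze(0),
            "logprobs": torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0),
            "loss_mask": torch.tensor(loss_mask, dtype=torch.bool).unsqueeze(0),
            "versions": torch.tensor(versions, dtype=torch.long).unsqueeze(0),
            "rewards": torch.tensor([reward], dtype=torch.float32),
        },
        batch_size=[1],  # one trajectory per arun_episode call
    )
```

Annotating the return type also documents, at the interface level, that every episode must come back in the same tensor layout.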
@@ -250,18 +250,13 @@ def main(args):
     )
 
     # Run training—no other changes needed!
-    data_generator = iter(train_dataloader)
+    data_generator = itertools.cycle(train_dataloader)
     for global_step in range(max_steps):
         with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
-               try:
-                   data = next(data_generator)
-               except StopIteration:
-                   data_generator = iter(train_dataloader)
-                   data = next(data_generator)
-               batch = rollout.rollout_batch(data, workflow=workflow)
+               batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
     # ... continue with training loop ...
 ```
 
@@ -171,16 +171,11 @@ def main(args):
     )
 
     # Main training loop
     data_generator = itertools.cycle(dataloader)
     for global_step in range(max_steps):
         # Generate training data
         with stats_tracker.record_timing("rollout"):
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
 
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
         batch = batch.to(actor.device)