晓雷 2025-08-01 16:11:07 +08:00
parent 04b26f42bb
commit 0291191716
3 changed files with 13 additions and 37 deletions
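The change is the same in all three documentation files: a manual `iter()`-plus-`StopIteration` reset is replaced by `itertools.cycle`, which restarts the iterator transparently. A minimal self-contained sketch of the before/after behavior (a plain list of batches stands in for the actual DataLoader):

```python
import itertools

dataloader = [[1, 2], [3, 4]]  # stand-in for a torch DataLoader yielding batches

# Before: next() raises StopIteration at the end of each epoch,
# so the iterator has to be rebuilt by hand.
data_generator = iter(dataloader)
try:
    data = next(data_generator)
except StopIteration:
    data_generator = iter(dataloader)
    data = next(data_generator)

# After: cycle() re-yields the batches endlessly, so a bare next()
# never raises StopIteration (for a non-empty iterable).
data_generator = itertools.cycle(dataloader)
data = next(data_generator)  # safe to call any number of times
```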

View File

@@ -323,7 +323,7 @@ def prepare_batch(
         workflow: "RolloutWorkflow",
     ):
         if not hasattr(self, "data_generator"):
-            self.data_generator = iter(dataloader)
+            self.data_generator = itertools.cycle(dataloader)
         assert dataloader.batch_size is not None
         while True:
             # Submit at least two batches to allow maximum overlap
@@ -332,11 +332,7 @@
                 and self.input_queue.qsize() + dataloader.batch_size
                 < self.input_queue.maxsize
             ):
-                try:
-                    data = next(self.data_generator)
-                except StopIteration:
-                    self.data_generator = iter(dataloader)
-                    data = next(self.data_generator)
+                data = next(self.data_generator)
                 for item in data:
                     # submit data into input_queue
                     self.submit(item, workflow=workflow)
@@ -364,18 +360,13 @@ rollout = RemoteSGLangEngine(config.inf_engine)
 rollout.initialize()
 eval_rollout = ...

-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     # rollout batched training data for current step
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 ```

 If you want to use rollout workflows with custom reward functions or agentic tool
@@ -474,17 +465,12 @@ Now a complete GRPO training step in AReaL-lite is done! The core logic of our e
 training script can be summarized as:

 ```python
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
     logp = actor.compute_logp(batch)
     batch["prox_logp"] = logp
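The `prepare_batch` hunk above keeps the capacity check untouched; only the refill step changes. A rough sketch of that submission loop using a plain `queue.Queue` (the engine's `input_queue`, `submit()`, and the first clause of the `while` condition sit outside the diff context, so they are assumed here):

```python
import itertools
import queue

input_queue = queue.Queue(maxsize=8)  # assumed capacity
batch_size = 2                        # stands in for dataloader.batch_size
data_generator = itertools.cycle([[1, 2], [3, 4], [5, 6]])

# Keep pulling batches while a full batch still fits in the queue;
# this is what lets rollout generation run ahead of training.
while input_queue.qsize() + batch_size < input_queue.maxsize:
    for item in next(data_generator):
        input_queue.put(item)  # stands in for self.submit(item, workflow=workflow)
```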

View File

@@ -79,7 +79,7 @@ and converting it into an `LLMRequest` object for the inference engine:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... __init__ method above ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # Initialize result containers
         seq, logprobs, loss_mask, versions = [], [], [], []
         messages = data["messages"]
@@ -124,7 +124,7 @@ apply a discount, add feedback to the conversation, and let the model try again:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... initialization code ...
         while reward == 0 and t < self.max_turns:
             # Add feedback if the previous answer was incorrect
@@ -195,7 +195,7 @@ Finally, let's complete the implementation by collecting trajectories in the
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... episode logic above ...
         while reward == 0 and t < self.max_turns:
@@ -250,18 +250,13 @@ def main(args):
     )

     # Run training—no other changes needed!
-    data_generator = iter(train_dataloader)
+    data_generator = itertools.cycle(train_dataloader)
     for global_step in range(max_steps):
         with stats_tracker.record_timing("rollout"):
             if config.async_training:
                 batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
             else:
-                try:
-                    data = next(data_generator)
-                except StopIteration:
-                    data_generator = iter(train_dataloader)
-                    data = next(data_generator)
-                batch = rollout.rollout_batch(data, workflow=workflow)
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

         # ... continue with training loop ...
 ```
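The `-> TensorDict` annotations added above presumably only document what `arun_episode` already returns. For readers unfamiliar with the type, a sketch of packing the collected containers (`seq`, `logprobs`, `loss_mask`, `versions`) into a single-trajectory `TensorDict`; the exact field names AReaL uses may differ:

```python
import torch
from tensordict import TensorDict

def pack_trajectory(seq, logprobs, loss_mask, versions, reward) -> TensorDict:
    # One leading batch dimension so trajectories can later be stacked.
    return TensorDict(
        {
            "input_ids": torch.tensor(seq, dtype=torch.long).unsqueeze(0),
            "logprobs": torch.tensor(logprobs).unsqueeze(0),
            "loss_mask": torch.tensor(loss_mask, dtype=torch.bool).unsqueeze(0),
            "versions": torch.tensor(versions, dtype=torch.long).unsqueeze(0),
            "rewards": torch.tensor([reward]),
        },
        batch_size=[1],
    )
```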

View File

@@ -171,16 +171,11 @@ def main(args):
     )

     # Main training loop
-    data_generator = iter(dataloader)
+    data_generator = itertools.cycle(dataloader)
     for global_step in range(max_steps):
         # Generate training data
         with stats_tracker.record_timing("rollout"):
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
         batch = batch.to(actor.device)
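One behavioral nuance of the new pattern worth keeping in mind: `itertools.cycle` caches the items it sees on the first pass and replays those same objects afterwards, whereas rebuilding `iter(dataloader)` each epoch lets a shuffling DataLoader produce a fresh order. A tiny demonstration with a generator that shuffles on creation:

```python
import itertools
import random

def shuffled_batches():
    batches = [1, 2, 3]
    random.shuffle(batches)  # a new order each time the generator is created
    yield from batches

cycled = itertools.cycle(shuffled_batches())
print([next(cycled) for _ in range(6)])
# e.g. [2, 1, 3, 2, 1, 3]: the first pass's order is simply repeated
```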