From 0291191716df7dcd9247281de36fe433a8dfd668 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=99=93=E9=9B=B7?=
Date: Fri, 1 Aug 2025 16:11:07 +0800
Subject: [PATCH] .

---
 docs/arealite/gsm8k_grpo.md     | 26 ++++++--------------------
 docs/customization/agent.md     | 15 +++++----------
 docs/customization/algorithm.md |  9 ++-------
 3 files changed, 13 insertions(+), 37 deletions(-)

diff --git a/docs/arealite/gsm8k_grpo.md b/docs/arealite/gsm8k_grpo.md
index cb4851c..7d7a3b0 100644
--- a/docs/arealite/gsm8k_grpo.md
+++ b/docs/arealite/gsm8k_grpo.md
@@ -323,7 +323,7 @@ def prepare_batch(
     workflow: "RolloutWorkflow",
 ):
     if not hasattr(self, "data_generator"):
-        self.data_generator = iter(dataloader)
+        self.data_generator = itertools.cycle(dataloader)
     assert dataloader.batch_size is not None
     while True:
         # Submit at least two batches to allow maximum overlap
@@ -332,11 +332,7 @@ def prepare_batch(
             and self.input_queue.qsize() + dataloader.batch_size
             < self.input_queue.maxsize
         ):
-            try:
-                data = next(self.data_generator)
-            except StopIteration:
-                self.data_generator = iter(dataloader)
-                data = next(self.data_generator)
+            data = next(self.data_generator)
             for item in data:
                 # submit data into input_queue
                 self.submit(item, workflow=workflow)
@@ -364,18 +360,13 @@ rollout = RemoteSGLangEngine(config.inf_engine)
 rollout.initialize()
 eval_rollout = ...
 
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     # rollout batched training data for current step
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 ```
 
 If you want to use rollout workflows with custom reward functions or agentic tool
@@ -474,17 +465,12 @@ Now a complete GRPO training step in AReaL-lite is done! The core logic of our e
 training script can be summarized as:
 
 ```python
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
     logp = actor.compute_logp(batch)
     batch["prox_logp"] = logp
diff --git a/docs/customization/agent.md b/docs/customization/agent.md
index b519080..5804d0e 100644
--- a/docs/customization/agent.md
+++ b/docs/customization/agent.md
@@ -79,7 +79,7 @@ and converting it into an `LLMRequest` object for the inference engine:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... __init__ method above ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # Initialize result containers
         seq, logprobs, loss_mask, versions = [], [], [], []
         messages = data["messages"]
@@ -124,7 +124,7 @@ apply a discount, add feedback to the conversation, and let the model try again:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... initialization code ...
         while reward == 0 and t < self.max_turns:
             # Add feedback if the previous answer was incorrect
@@ -195,7 +195,7 @@ Finally, let's complete the implementation by collecting trajectories in the
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... episode logic above ...
 
         while reward == 0 and t < self.max_turns:
@@ -250,18 +250,13 @@ def main(args):
     )
 
     # Run training—no other changes needed!
-    data_generator = iter(train_dataloader)
+    data_generator = itertools.cycle(train_dataloader)
     for global_step in range(max_steps):
         with stats_tracker.record_timing("rollout"):
             if config.async_training:
                 batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
             else:
-                try:
-                    data = next(data_generator)
-                except StopIteration:
-                    data_generator = iter(train_dataloader)
-                    data = next(data_generator)
-                batch = rollout.rollout_batch(data, workflow=workflow)
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
     # ... continue with training loop ...
 ```
diff --git a/docs/customization/algorithm.md b/docs/customization/algorithm.md
index 734bdd1..b15a3fd 100644
--- a/docs/customization/algorithm.md
+++ b/docs/customization/algorithm.md
@@ -171,16 +171,11 @@ def main(args):
     )
 
     # Main training loop
+    data_generator = itertools.cycle(dataloader)
     for global_step in range(max_steps):
         # Generate training data
         with stats_tracker.record_timing("rollout"):
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
-
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
         batch = batch.to(actor.device)
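
For reference only, not part of the patch: a minimal, self-contained sketch of the `itertools.cycle` pattern these hunks adopt, using a plain list of batches as a stand-in for the real dataloader. Two points worth noting: the patched doc snippets assume `itertools` is imported in the surrounding example code, and `itertools.cycle` saves the items produced on the first pass and replays them in the same order, so a dataloader that reshuffles each time it is re-iterated will not reshuffle under this pattern.

```python
# Illustrative sketch; the list below stands in for a real dataloader.
import itertools

dataloader = [["q1", "q2"], ["q3", "q4"], ["q5", "q6"]]  # three "batches"

# Old pattern: re-create the iterator whenever it is exhausted.
data_generator = iter(dataloader)
for _ in range(5):
    try:
        data = next(data_generator)
    except StopIteration:
        data_generator = iter(dataloader)
        data = next(data_generator)
    print("old:", data)

# New pattern: a single next() call, which never raises StopIteration.
data_generator = itertools.cycle(dataloader)
for _ in range(5):
    print("new:", next(data_generator))
```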