bowei.fw 2025-07-14 16:39:59 +08:00
commit d8bd161b1d
5 changed files with 24 additions and 27 deletions

View File

@@ -92,7 +92,7 @@ def main_grpo():
         future.result()
     # synchronous rollout
-    rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
+    rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
     # or asynchronous rollout with filtering and off-policyness control
     # rollout_batch = rollout.prepare_batch(batch,
     #     workflow=MyRolloutWorkflow(rollout_config.workflow),
@@ -697,7 +697,7 @@ reward = TrainController(Critic())
 rollout_controller = RolloutController(...)
 for _ in range(epochs):
     for _ in range(steps_per_epoch):
-        data = rollout_controller.rollout(prompt)
+        data = rollout_controller.rollout_batch(prompt)
         data['reward'] = reward.compute_values(data)
         ...
 ```
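
Taken together, the rename reads as follows in a full loop. A minimal sketch, assuming `dataloader` and the trainer call are scaffolding added for illustration; `rollout`, `rollout_config`, and `MyRolloutWorkflow` come from the documentation snippet above, and this is not the library's own code:

```python
# Sketch of the renamed API in a training loop. `dataloader`,
# `train_engine`, `loss_fn`, and `loss_weight_fn` are assumptions
# for illustration; the rollout calls follow the docs diff above.
for batch in dataloader:
    # Synchronous path: block until every request in `batch` finishes.
    rollout_batch = rollout.rollout_batch(
        batch, workflow=MyRolloutWorkflow(rollout_config.workflow)
    )
    # Asynchronous alternative (commented in the docs): submit prompts
    # continuously and return once a full batch of results is ready.
    # rollout_batch = rollout.prepare_batch(
    #     dataloader, workflow=MyRolloutWorkflow(rollout_config.workflow)
    # )
    train_engine.train_batch(rollout_batch, loss_fn, loss_weight_fn)  # hypothetical trainer
```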

View File

@@ -339,7 +339,6 @@ class SGLangConfig:
     model_path: str = ""
     random_seed: int = 1
     skip_tokenizer_init: bool = False
-
     disable_cuda_graph: bool = False
     disable_radix_cache: bool = False
     disable_cuda_graph_padding: bool = False
@@ -375,10 +374,8 @@ class SGLangConfig:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-
     dtype: str = "float16"
     kv_cache_dtype: str = "auto"
-
     # logging
     log_level: str = "warning"
     log_level_http: Optional[str] = "warning"
@@ -439,7 +436,6 @@ class SGLangConfig:
             raise ValueError(
                 "A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
             )
-
         if version_less_than_0_4_4:
             args.pop("log_requests_level")
         if version_less_than_0_4_3:
@@ -449,7 +445,6 @@ class SGLangConfig:
             args.pop("enable_memory_saver")
             args.pop("allow_auto_truncate")
             args.pop("file_storage_path")
-
         flags = []
         for k, v in args.items():
             if v is None or v is False or v == "":
@@ -506,11 +501,6 @@ class InferenceEngineConfig:
     )
 
 
-@dataclass
-class SGLangEngineConfig:
-    pass
-
-
 @dataclass
 class _Timer:
     experiment_name: str = MISSING
@@ -807,7 +797,6 @@ def parse_cli_args(argv: List[str]):
         "--config", help="The path of the main configuration file", required=True
     )
     args, overrides = parser.parse_known_args(argv)
-
     # Initialize hydra config
     config_file = Path(args.config).absolute()
     assert config_file.exists()
@@ -834,7 +823,6 @@ def load_expr_config(argv: List[str], config_cls):
     cfg = to_structured_cfg(cfg, config_cls=config_cls)
     cfg = OmegaConf.to_object(cfg)
     assert isinstance(cfg, BaseExperimentConfig)
-
     # Setup environment
     from realhf.base import constants, name_resolve, names
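For orientation, here is how these config helpers are typically wired together in an entrypoint. A minimal sketch under stated assumptions: `MyExpConfig`, the import path, and the return convention of `load_expr_config` are assumptions; only `parse_cli_args` and `load_expr_config` themselves appear in this diff:

```python
# Hypothetical entrypoint using the helpers shown above. MyExpConfig is
# an assumed BaseExperimentConfig subclass, and the return value of
# load_expr_config is an assumption for illustration.
import sys

from arealite.api.cli_args import load_expr_config  # assumed module path

if __name__ == "__main__":
    # parse_cli_args requires --config; remaining "key=value" tokens are
    # treated as OmegaConf overrides, e.g.:
    #   python train.py --config exp.yaml sglang.dtype=bfloat16
    cfg = load_expr_config(sys.argv[1:], MyExpConfig)
```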

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.io_struct import (
     FinetuneSpec,
@@ -78,8 +79,8 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> Dict[str, float]:
         """Update the model with a batch of data and a loss function."""
         raise NotImplementedError()
@@ -88,8 +89,8 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> torch.Tensor | None:
         """Evaluate the model using the forward pass and loss function."""
         raise NotImplementedError()
@@ -99,7 +100,7 @@ class TrainEngine(abc.ABC):
         self,
         input_: TensorDict,
         output_seqlens: List[List[int]] | None = None,
-        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
+        post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:
         """Run the forward pass or inference on the model. Note that it is gradient-free."""
@@ -136,6 +137,20 @@ class InferenceEngine(abc.ABC):
         """Wait for a specified number of requests to complete, with a timeout."""
         raise NotImplementedError()
 
+    def rollout_batch(
+        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+    ) -> TensorDict:
+        """Submit a batch of requests to the inference engine and wait for the results."""
+        raise NotImplementedError()
+
+    def prepare_batch(
+        self,
+        dataloader: StatefulDataLoader,
+        workflow: "RolloutWorkflow",
+    ):
+        """Asynchronously submit and wait until a full batch is ready."""
+        raise NotImplementedError()
+
     def pause(self):
         """Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
         raise NotImplementedError()
@@ -143,9 +158,3 @@ class InferenceEngine(abc.ABC):
     def resume(self):
         """Resume request submission for async rollout."""
         raise NotImplementedError()
-
-    def rollout(
-        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
-    ) -> TensorDict:
-        """Submit a batch of requests to the inference engine and wait for the results."""
-        raise NotImplementedError()
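
To make the new interface concrete, a toy subclass might look like the following. This is a sketch under stated assumptions: `run_one` is a hypothetical helper, `concat_padded_tensors` is borrowed from the remote engine in the next file, and only the method signatures come from this diff:

```python
# Toy InferenceEngine sketch following the abstract signatures above.
from typing import Any, Dict, List

from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader


class ToyInferenceEngine(InferenceEngine):
    def rollout_batch(
        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
    ) -> TensorDict:
        # Synchronous: run every request to completion, then pad + concat.
        results = [run_one(workflow, item) for item in data]  # hypothetical helper
        return concat_padded_tensors(results)  # aggregation RemoteSGLangEngine also uses

    def prepare_batch(
        self, dataloader: StatefulDataLoader, workflow: "RolloutWorkflow"
    ) -> TensorDict:
        # Asynchronous: keep submitting prompts from the dataloader and
        # block only until one full batch of rollouts has finished.
        # (Real engines pair this with pause()/resume() around evals.)
        ...
```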

View File

@@ -481,7 +481,7 @@ class RemoteSGLangEngine(InferenceEngine):
         )
         return concat_padded_tensors(results)
 
-    def rollout(
+    def rollout_batch(
        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
    ) -> TensorDict:
        """Submit a batch of requests to the inference engine and wait for the results."""

View File

@@ -108,7 +108,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
     data = {
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
-    result = engine.rollout([data] * 2, workflow=workflow)
+    result = engine.rollout_batch([data] * 2, workflow=workflow)
     assert isinstance(result, TensorDict)
     bs = result.batch_size
     assert bs == torch.Size([2 * n_samples])