bowei.fw 2025-07-14 16:39:59 +08:00
commit d8bd161b1d
5 changed files with 24 additions and 27 deletions

View File

@@ -92,7 +92,7 @@ def main_grpo():
future.result()
# synchronous rollout
-rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
+rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
# or asynchronous rollout with filtering and off-policyness control
# rollout_batch = rollout.prepare_batch(batch,
# workflow=MyRolloutWorkflow(rollout_config.workflow),
@@ -697,7 +697,7 @@ reward = TrainController(Critic())
rollout_controller = RolloutController(...)
for _ in range(epochs):
for _ in range(steps_per_epoch):
-        data = rollout_controller.rollout(prompt)
+        data = rollout_controller.rollout_batch(prompt)
data['reward'] = reward.compute_values(data)
...
```
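For orientation, the renamed call returns a `TensorDict` (per the `InferenceEngine` signature later in this commit), which feeds directly into training. A hedged sketch of one step after the rename; `rollout`, `batch`, `MyRolloutWorkflow`, and `rollout_config` come from the documentation excerpt above, while `actor`, `grpo_loss`, and `token_weight` are illustrative placeholders, not names from this diff:

```python
# Illustrative sketch only, not part of the diff.
rollout_batch = rollout.rollout_batch(
    batch, workflow=MyRolloutWorkflow(rollout_config.workflow)
)  # blocks until the whole batch of trajectories is ready (a TensorDict)
stats = actor.train_batch(
    rollout_batch,
    loss_fn=grpo_loss,            # Callable[[torch.Tensor, TensorDict], torch.Tensor]
    loss_weight_fn=token_weight,  # Callable[[TensorDict], float]
)
```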

View File

@@ -339,7 +339,6 @@ class SGLangConfig:
model_path: str = ""
random_seed: int = 1
skip_tokenizer_init: bool = False
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_cuda_graph_padding: bool = False
@@ -375,10 +374,8 @@ class SGLangConfig:
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
dtype: str = "float16"
kv_cache_dtype: str = "auto"
# logging
log_level: str = "warning"
log_level_http: Optional[str] = "warning"
@@ -439,7 +436,6 @@ class SGLangConfig:
raise ValueError(
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
)
if version_less_than_0_4_4:
args.pop("log_requests_level")
if version_less_than_0_4_3:
@@ -449,7 +445,6 @@ class SGLangConfig:
args.pop("enable_memory_saver")
args.pop("allow_auto_truncate")
args.pop("file_storage_path")
flags = []
for k, v in args.items():
if v is None or v is False or v == "":
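The surviving context lines above show the start of the server-command construction: options whose value is `None`, `False`, or empty produce no flag. As a sketch of the overall pattern (only the skip condition is visible in this diff; the handling of truthy values below is an assumption, with example keys taken from the config fields above):

```python
# Assumed continuation of the flag-building loop shown above; illustrative only.
args = {"model_path": "/models/qwen", "skip_tokenizer_init": False,
        "disable_radix_cache": True, "log_level": "warning", "cpu_offload_gb": 0}
flags = []
for k, v in args.items():
    if v is None or v is False or v == "":
        continue  # unset or disabled options emit no flag
    if v is True:
        flags.append(f"--{k.replace('_', '-')}")  # bare boolean switch
    else:
        flags.append(f"--{k.replace('_', '-')} {v}")  # flag with value
cmd = " ".join(["python3 -m sglang.launch_server"] + flags)
# -> python3 -m sglang.launch_server --model-path /models/qwen
#    --disable-radix-cache --log-level warning --cpu-offload-gb 0
```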
@@ -506,11 +501,6 @@ class InferenceEngineConfig:
)
-@dataclass
-class SGLangEngineConfig:
-    pass
@dataclass
class _Timer:
experiment_name: str = MISSING
@@ -807,7 +797,6 @@ def parse_cli_args(argv: List[str]):
"--config", help="The path of the main configuration file", required=True
)
args, overrides = parser.parse_known_args(argv)
# Initialize hydra config
config_file = Path(args.config).absolute()
assert config_file.exists()
@@ -834,7 +823,6 @@ def load_expr_config(argv: List[str], config_cls):
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)
assert isinstance(cfg, BaseExperimentConfig)
# Setup environment
from realhf.base import constants, name_resolve, names

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.io_struct import (
FinetuneSpec,
@@ -78,8 +79,8 @@ class TrainEngine(abc.ABC):
def train_batch(
self,
input_: TensorDict,
-    loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-    loss_weight_fn: Callable[[Dict], float],
+    loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+    loss_weight_fn: Callable[[TensorDict], float],
) -> Dict[str, float]:
"""Update the model with a batch of data and a loss function."""
raise NotImplementedError()
@@ -88,8 +89,8 @@ class TrainEngine(abc.ABC):
def eval_batch(
self,
input_: TensorDict,
-    loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-    loss_weight_fn: Callable[[Dict], float],
+    loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+    loss_weight_fn: Callable[[TensorDict], float],
) -> torch.Tensor | None:
"""Evaluate the model using the forward pass and loss function."""
raise NotImplementedError()
@@ -99,7 +100,7 @@ class TrainEngine(abc.ABC):
self,
input_: TensorDict,
output_seqlens: List[List[int]] | None = None,
-    post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
+    post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Run the forward pass or inference on the model. Note that it is gradient-free."""
@@ -136,6 +137,20 @@ class InferenceEngine(abc.ABC):
"""Wait for a specified number of requests to complete, with a timeout."""
raise NotImplementedError()
+def rollout_batch(
+    self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+) -> TensorDict:
+    """Submit a batch of requests to the inference engine and wait for the results."""
+    raise NotImplementedError()
+
+def prepare_batch(
+    self,
+    dataloader: StatefulDataLoader,
+    workflow: "RolloutWorkflow",
+):
+    """Asynchronously submit and wait until a full batch is ready."""
+    raise NotImplementedError()
def pause(self):
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
raise NotImplementedError()
@@ -143,9 +158,3 @@ class InferenceEngine(abc.ABC):
def resume(self):
"""Resume request submission for async rollout."""
raise NotImplementedError()
-def rollout(
-    self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
-) -> TensorDict:
-    """Submit a batch of requests to the inference engine and wait for the results."""
-    raise NotImplementedError()
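Read together, the new surface separates the blocking path from the asynchronous one, with `pause`/`resume` bracketing evaluation to avoid over-generation. A usage sketch based only on the docstrings above; `engine`, `data`, `dataloader`, `workflow`, and `evaluate` are assumed to exist:

```python
# Sketch of post-rename InferenceEngine usage, based on the docstrings above.

# Synchronous: submit `data` and block until every request completes.
batch = engine.rollout_batch(data, workflow=workflow)

# Asynchronous: keep drawing from the dataloader until a full batch of
# results is ready (enables filtering and off-policyness control).
batch = engine.prepare_batch(dataloader, workflow=workflow)

# Stop submitting new requests while evaluating, then resume.
engine.pause()
evaluate()  # placeholder for the caller's evaluation routine
engine.resume()
```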

View File

@@ -481,7 +481,7 @@ class RemoteSGLangEngine(InferenceEngine):
)
return concat_padded_tensors(results)
-def rollout(
+def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""

View File

@@ -108,7 +108,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
-result = engine.rollout([data] * 2, workflow=workflow)
+result = engine.rollout_batch([data] * 2, workflow=workflow)
assert isinstance(result, TensorDict)
bs = result.batch_size
assert bs == torch.Size([2 * n_samples])
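The assertion implies each of the two submitted prompts expands into `n_samples` trajectories in the returned `TensorDict`. For intuition, a plausible default implementation of `rollout_batch` in terms of the `wait` primitive documented earlier; the `submit` method is an assumption, not shown in this diff:

```python
from typing import Any, Dict, List
from tensordict import TensorDict


def rollout_batch(self, data: List[Dict[str, Any]], workflow) -> TensorDict:
    # Hypothetical sketch: enqueue every request, then block until all of
    # them complete. Only `wait` ("Wait for a specified number of requests
    # to complete, with a timeout.") appears in this diff; `submit` is assumed.
    for item in data:
        self.submit(item, workflow=workflow)
    return self.wait(count=len(data))
```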