mirror of https://github.com/inclusionAI/AReaL
Commit: d8bd161b1d
@@ -92,7 +92,7 @@ def main_grpo():
     future.result()
 
     # synchronous rollout
-    rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
+    rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
     # or asynchronous rollout with filtering and off-policyness control
     # rollout_batch = rollout.prepare_batch(batch,
     #     workflow=MyRolloutWorkflow(rollout_config.workflow),
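For readers skimming the rename: after this commit `rollout_batch` is the synchronous entry point (submit a list of requests and block for the results), while `prepare_batch` covers the asynchronous path (stream prompts from a dataloader and return once a full batch is ready). A minimal sketch of the distinction, reusing only names from the surrounding doc; `dataloader` is an assumed stand-in for the async input source:

```python
workflow = MyRolloutWorkflow(rollout_config.workflow)

# Synchronous: blocks until every request in `batch` has a result.
rollout_batch = rollout.rollout_batch(batch, workflow=workflow)

# Asynchronous: keeps the server busy from `dataloader` and returns as soon
# as one full batch is ready, enabling filtering and off-policyness control.
rollout_batch = rollout.prepare_batch(dataloader, workflow=workflow)
```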
@@ -697,7 +697,7 @@ reward = TrainController(Critic())
 rollout_controller = RolloutController(...)
 for _ in range(epochs):
     for _ in range(steps_per_epoch):
-        data = rollout_controller.rollout(prompt)
+        data = rollout_controller.rollout_batch(prompt)
         data['reward'] = reward.compute_values(data)
         ...
 ```
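A slightly expanded sketch of the controller loop above, with the actor update filled in; `actor`, `grpo_loss`, and `compute_advantages` are assumed names for illustration and are not part of this commit:

```python
for _ in range(epochs):
    for _ in range(steps_per_epoch):
        data = rollout_controller.rollout_batch(prompt)  # collect trajectories
        data["reward"] = reward.compute_values(data)     # score them with the critic
        data["advantage"] = compute_advantages(data)     # assumed helper, e.g. group-normalized rewards
        actor.train_batch(data, loss_fn=grpo_loss, loss_weight_fn=len)
```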
@@ -339,7 +339,6 @@ class SGLangConfig:
     model_path: str = ""
     random_seed: int = 1
     skip_tokenizer_init: bool = False
-
     disable_cuda_graph: bool = False
     disable_radix_cache: bool = False
     disable_cuda_graph_padding: bool = False
@@ -375,10 +374,8 @@ class SGLangConfig:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-
     dtype: str = "float16"
     kv_cache_dtype: str = "auto"
-
     # logging
     log_level: str = "warning"
     log_level_http: Optional[str] = "warning"
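Since SGLangConfig is a plain dataclass, launch-time configuration is just construction with keyword overrides. A minimal sketch using only fields visible in this diff; the import path is assumed from the arealite layout:

```python
from arealite.api.cli_args import SGLangConfig  # module path assumed

config = SGLangConfig(
    model_path="/models/my-model",  # hypothetical local path
    random_seed=42,
    dtype="float16",
    kv_cache_dtype="auto",
    log_level="warning",
)
```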
@@ -439,7 +436,6 @@ class SGLangConfig:
             raise ValueError(
                 "An installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
             )
-
         if version_less_than_0_4_4:
             args.pop("log_requests_level")
         if version_less_than_0_4_3:
@@ -449,7 +445,6 @@ class SGLangConfig:
             args.pop("enable_memory_saver")
             args.pop("allow_auto_truncate")
             args.pop("file_storage_path")
-
         flags = []
         for k, v in args.items():
             if v is None or v is False or v == "":
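The loop above serializes the args dict into CLI flags, skipping anything unset: None, False, and empty strings are dropped, while True presumably becomes a bare flag. A self-contained sketch of that convention; the helper name and the underscore-to-dash mapping are illustrative, not taken from this commit:

```python
def args_to_flags(args: dict) -> list[str]:
    flags = []
    for k, v in args.items():
        if v is None or v is False or v == "":
            continue  # unset or disabled: emit nothing
        if v is True:
            flags.append(f"--{k.replace('_', '-')}")  # bare boolean flag
        else:
            flags.append(f"--{k.replace('_', '-')} {v}")
    return flags

# Note the identity checks: 0 is neither None nor False, so it survives.
print(args_to_flags({"disable_cuda_graph": True, "dtype": "float16", "cpu_offload_gb": 0}))
# ['--disable-cuda-graph', '--dtype float16', '--cpu-offload-gb 0']
```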
@@ -506,11 +501,6 @@ class InferenceEngineConfig:
     )
 
 
-@dataclass
-class SGLangEngineConfig:
-    pass
-
-
 @dataclass
 class _Timer:
     experiment_name: str = MISSING
@@ -807,7 +797,6 @@ def parse_cli_args(argv: List[str]):
         "--config", help="The path of the main configuration file", required=True
     )
     args, overrides = parser.parse_known_args(argv)
-
     # Initialize hydra config
     config_file = Path(args.config).absolute()
     assert config_file.exists()
@@ -834,7 +823,6 @@ def load_expr_config(argv: List[str], config_cls):
     cfg = to_structured_cfg(cfg, config_cls=config_cls)
     cfg = OmegaConf.to_object(cfg)
     assert isinstance(cfg, BaseExperimentConfig)
-
     # Setup environment
     from realhf.base import constants, name_resolve, names
 
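Together these two helpers give a two-stage startup: argparse peels off `--config`, and the remaining argv entries become Hydra-style dotted overrides merged into the YAML file. A hedged usage sketch; `GRPOConfig` stands in for whatever `BaseExperimentConfig` subclass an experiment defines, and the exact return value of `load_expr_config` is assumed:

```python
import sys

from arealite.api.cli_args import load_expr_config  # module path assumed

def main(argv):
    # e.g. python main.py --config configs/grpo.yaml actor.lr=1e-5
    cfg = load_expr_config(argv, GRPOConfig)  # hypothetical config class
    ...

if __name__ == "__main__":
    main(sys.argv[1:])
```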
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.io_struct import (
     FinetuneSpec,
@@ -78,8 +79,8 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> Dict[str, float]:
         """Update the model with a batch of data and a loss function."""
         raise NotImplementedError()
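The signature change above means a custom loss now receives the full TensorDict rather than a plain dict. A minimal sketch of a conforming `loss_fn`; the `"labels"` and `"loss_mask"` keys are assumptions for illustration, not part of this commit:

```python
import torch
from tensordict import TensorDict

def masked_ce_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
    labels = input_["labels"]           # (batch, seqlen), assumed key
    mask = input_["loss_mask"].float()  # 1.0 where a token counts toward the loss
    ce = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), reduction="none"
    ).view_as(labels)
    return (ce * mask).sum() / mask.sum().clamp(min=1.0)

# stats = engine.train_batch(batch, loss_fn=masked_ce_loss,
#                            loss_weight_fn=lambda td: td["loss_mask"].sum().item())
```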
@@ -88,8 +89,8 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> torch.Tensor | None:
         """Evaluate the model using the forward pass and loss function."""
         raise NotImplementedError()
@@ -99,7 +100,7 @@ class TrainEngine(abc.ABC):
         self,
         input_: TensorDict,
         output_seqlens: List[List[int]] | None = None,
-        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
+        post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:
         """Run the forward pass or inference on the model. Note that it is gradient-free."""
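`post_hook` sees the same TensorDict under the new signature, which is convenient for reducing each chunk's logits before `aggregate_fn` concatenates the results, instead of materializing all logits at once. A hedged sketch; the `"input_ids"` key is assumed:

```python
import torch
from tensordict import TensorDict

def gather_logprobs(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
    logp = torch.log_softmax(logits, dim=-1)
    # Keep only the log-prob of each observed token: (batch, seqlen).
    return torch.gather(logp, -1, input_["input_ids"].unsqueeze(-1)).squeeze(-1)

# logprobs = engine.forward(batch, post_hook=gather_logprobs, aggregate_fn=torch.cat)
```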
@@ -136,6 +137,20 @@ class InferenceEngine(abc.ABC):
         """Wait for a specified number of requests to complete, with a timeout."""
         raise NotImplementedError()
 
+    def rollout_batch(
+        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+    ) -> TensorDict:
+        """Submit a batch of requests to the inference engine and wait for the results."""
+        raise NotImplementedError()
+
+    def prepare_batch(
+        self,
+        dataloader: StatefulDataLoader,
+        workflow: "RolloutWorkflow",
+    ):
+        """Asynchronously submit and wait until a full batch is ready."""
+        raise NotImplementedError()
+
     def pause(self):
         """Pause request submission for async rollout. Used during evaluation to prevent over-generation of data."""
         raise NotImplementedError()
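The `pause`/`resume` pair exists so a long-running async generation loop can be fenced off during evaluation. A sketch of the intended call pattern under the abstract signatures above; `run_evaluation` is an assumed placeholder:

```python
batch = engine.prepare_batch(dataloader, workflow=workflow)

engine.pause()     # stop submitting new rollout requests
run_evaluation()   # assumed helper: evaluate without over-generating data
engine.resume()    # async rollout picks up where it left off
```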
@@ -143,9 +158,3 @@ class InferenceEngine(abc.ABC):
     def resume(self):
         """Resume request submission for async rollout."""
         raise NotImplementedError()
-
-    def rollout(
-        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
-    ) -> TensorDict:
-        """Submit a batch of requests to the inference engine and wait for the results."""
-        raise NotImplementedError()
@@ -481,7 +481,7 @@ class RemoteSGLangEngine(InferenceEngine):
         )
         return concat_padded_tensors(results)
 
-    def rollout(
+    def rollout_batch(
         self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
     ) -> TensorDict:
         """Submit a batch of requests to the inference engine and wait for the results."""
@@ -108,7 +108,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
     data = {
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
-    result = engine.rollout([data] * 2, workflow=workflow)
+    result = engine.rollout_batch([data] * 2, workflow=workflow)
     assert isinstance(result, TensorDict)
     bs = result.batch_size
     assert bs == torch.Size([2 * n_samples])
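Note the expected batch size in the final assertion: two prompts, each expanded by the workflow into `n_samples` completions, so the returned TensorDict has `2 * n_samples` rows.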