mirror of https://github.com/inclusionAI/AReaL
Commit d8bd161b1d
@@ -92,7 +92,7 @@ def main_grpo():
     future.result()

     # synchronous rollout
-    rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
+    rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
     # or asynchronous rollout with filtering and off-policyness control
     # rollout_batch = rollout.prepare_batch(batch,
     #     workflow=MyRolloutWorkflow(rollout_config.workflow),
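A minimal usage sketch of the renamed call, assuming the `rollout` engine, `batch`, and `MyRolloutWorkflow` objects from the example above; the asynchronous branch is an assumption reconstructed from the commented-out lines and the `prepare_batch(dataloader, workflow)` signature introduced later in this commit.

```python
# Sketch only: `rollout`, `batch`, `dataloader`, and MyRolloutWorkflow are
# placeholders taken from the surrounding example, not a verified API surface.
workflow = MyRolloutWorkflow(rollout_config.workflow)

# Synchronous path: submit the whole batch and block until every result is back.
rollout_batch = rollout.rollout_batch(batch, workflow=workflow)

# Asynchronous path (assumed): keep the engine fed from a dataloader and
# return once a full batch passes the filtering / off-policyness checks.
# rollout_batch = rollout.prepare_batch(dataloader, workflow=workflow)
```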
@@ -697,7 +697,7 @@ reward = TrainController(Critic())
 rollout_controller = RolloutController(...)
 for _ in range(epochs):
     for _ in range(steps_per_epoch):
-        data = rollout_controller.rollout(prompt)
+        data = rollout_controller.rollout_batch(prompt)
         data['reward'] = reward.compute_values(data)
         ...
 ```
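For orientation, a hedged sketch of how the elided `...` step in this loop might continue once it is wired to the TensorDict-typed `train_batch` changed later in this commit; `actor`, `grpo_loss`, and the `"loss_mask"` key are hypothetical names, not part of the diff.

```python
# Hypothetical continuation of the `...` above. `actor` is assumed to be a
# TrainEngine; `grpo_loss` and the "loss_mask" key are illustrative only.
for _ in range(epochs):
    for _ in range(steps_per_epoch):
        data = rollout_controller.rollout_batch(prompt)
        data["reward"] = reward.compute_values(data)
        stats = actor.train_batch(
            data,
            loss_fn=grpo_loss,  # Callable[[torch.Tensor, TensorDict], torch.Tensor]
            loss_weight_fn=lambda d: float(d["loss_mask"].sum()),
        )
```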
@@ -339,7 +339,6 @@ class SGLangConfig:
     model_path: str = ""
     random_seed: int = 1
     skip_tokenizer_init: bool = False
-
     disable_cuda_graph: bool = False
     disable_radix_cache: bool = False
     disable_cuda_graph_padding: bool = False
@@ -375,10 +374,8 @@ class SGLangConfig:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-
     dtype: str = "float16"
     kv_cache_dtype: str = "auto"
-
     # logging
     log_level: str = "warning"
     log_level_http: Optional[str] = "warning"
@@ -439,7 +436,6 @@ class SGLangConfig:
             raise ValueError(
                 "A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
             )
-
         if version_less_than_0_4_4:
             args.pop("log_requests_level")
         if version_less_than_0_4_3:
@@ -449,7 +445,6 @@ class SGLangConfig:
             args.pop("enable_memory_saver")
             args.pop("allow_auto_truncate")
             args.pop("file_storage_path")
-
         flags = []
         for k, v in args.items():
             if v is None or v is False or v == "":
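The loop that starts here turns the remaining config entries into server CLI flags, skipping unset values. A hedged sketch of how such a conversion typically completes; the hunk truncates before the loop body ends, so the formatting rules below are assumptions rather than the diff's own code.

```python
# Assumed completion of the flag-building loop above; flag formatting and
# the launch command are illustrative, not taken from the diff.
flags = []
for k, v in args.items():
    if v is None or v is False or v == "":
        continue                                      # unset or disabled: emit nothing
    if v is True:
        flags.append(f"--{k.replace('_', '-')}")      # boolean switch
    else:
        flags.append(f"--{k.replace('_', '-')} {v}")  # key + value
cmd = f"python3 -m sglang.launch_server {' '.join(flags)}"
```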
@@ -506,11 +501,6 @@ class InferenceEngineConfig:
     )


-@dataclass
-class SGLangEngineConfig:
-    pass
-
-
 @dataclass
 class _Timer:
     experiment_name: str = MISSING
@@ -807,7 +797,6 @@ def parse_cli_args(argv: List[str]):
         "--config", help="The path of the main configuration file", required=True
     )
     args, overrides = parser.parse_known_args(argv)
-
     # Initialize hydra config
     config_file = Path(args.config).absolute()
     assert config_file.exists()
@@ -834,7 +823,6 @@ def load_expr_config(argv: List[str], config_cls):
     cfg = to_structured_cfg(cfg, config_cls=config_cls)
     cfg = OmegaConf.to_object(cfg)
     assert isinstance(cfg, BaseExperimentConfig)
-
     # Setup environment
     from realhf.base import constants, name_resolve, names

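Taken together, `parse_cli_args` and `load_expr_config` imply a Hydra/OmegaConf-style entry point: a required `--config` file plus dotted-key overrides collected by `parse_known_args`. A hypothetical driver; the module path, YAML file, override keys, and `MyExperimentConfig` are placeholders, not names fixed by this diff.

```python
# Hypothetical usage of the config-loading path above.
from arealite.api.cli_args import load_expr_config  # assumed module location

argv = [
    "--config", "configs/my_experiment.yaml",  # required by parse_cli_args
    "cluster.n_nodes=2", "train.lr=1e-5",      # leftover args become OmegaConf overrides
]
cfg = load_expr_config(argv, MyExperimentConfig)   # presumably returns the structured cfg
```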
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader

 from arealite.api.io_struct import (
     FinetuneSpec,
@@ -78,8 +79,8 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> Dict[str, float]:
         """Update the model with a batch of data and a loss function."""
         raise NotImplementedError()
@@ -88,8 +89,8 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> torch.Tensor | None:
         """Evaluate the model using the forward pass and loss function."""
         raise NotImplementedError()
@@ -99,7 +100,7 @@ class TrainEngine(abc.ABC):
         self,
         input_: TensorDict,
         output_seqlens: List[List[int]] | None = None,
-        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
+        post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:
         """Run the forward pass or inference on the model. Note that it is gradient-free."""
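The three hunks above only retype the second argument of `loss_fn`, `loss_weight_fn`, and `post_hook` from a plain `Dict` to a `TensorDict`, so the callables receive the same batch object that was passed to the engine. A sketch of a loss function under the new signature; the key names (`input_ids`, `loss_mask`) are assumptions, not fixed by this diff.

```python
import torch
from tensordict import TensorDict

def masked_nll_loss(logits: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """Per-token NLL averaged over unmasked positions (illustrative keys)."""
    logprobs = torch.log_softmax(logits, dim=-1)
    nll = -logprobs.gather(-1, data["input_ids"].unsqueeze(-1)).squeeze(-1)
    mask = data["loss_mask"].float()
    return (nll * mask).sum() / mask.sum().clamp(min=1)

# engine.train_batch(data, loss_fn=masked_nll_loss,
#                    loss_weight_fn=lambda d: float(d["loss_mask"].sum()))
```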
@@ -136,6 +137,20 @@ class InferenceEngine(abc.ABC):
         """Wait for a specified number of requests to complete, with a timeout."""
         raise NotImplementedError()

+    def rollout_batch(
+        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+    ) -> TensorDict:
+        """Submit a batch of requests to the inference engine and wait for the results."""
+        raise NotImplementedError()
+
+    def prepare_batch(
+        self,
+        dataloader: StatefulDataLoader,
+        workflow: "RolloutWorkflow",
+    ):
+        """Asynchronously submit and wait until a full batch is ready."""
+        raise NotImplementedError()
+
     def pause(self):
         """Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
         raise NotImplementedError()
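A usage sketch for the new batch-level entry points, assuming an `engine` that implements this interface; `prompt_dataset`, `workflow`, `steps`, and the batch size are placeholders. `StatefulDataLoader` matches the import added in the earlier hunk.

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Assumed usage; `engine`, `prompt_dataset`, `workflow`, and `steps` are
# placeholders, not names fixed by this commit.
dataloader = StatefulDataLoader(prompt_dataset, batch_size=8)

for _ in range(steps):
    # Async path: the engine keeps itself fed from the dataloader and
    # returns a TensorDict once a full batch of rollouts is ready.
    batch = engine.prepare_batch(dataloader, workflow=workflow)
```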
@@ -143,9 +158,3 @@ class InferenceEngine(abc.ABC):
     def resume(self):
         """Resume request submission for async rollout."""
         raise NotImplementedError()
-
-    def rollout(
-        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
-    ) -> TensorDict:
-        """Submit a batch of requests to the inference engine and wait for the results."""
-        raise NotImplementedError()
@@ -481,7 +481,7 @@ class RemoteSGLangEngine(InferenceEngine):
         )
         return concat_padded_tensors(results)

-    def rollout(
+    def rollout_batch(
         self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
     ) -> TensorDict:
         """Submit a batch of requests to the inference engine and wait for the results."""
@@ -108,7 +108,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
     data = {
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
-    result = engine.rollout([data] * 2, workflow=workflow)
+    result = engine.rollout_batch([data] * 2, workflow=workflow)
     assert isinstance(result, TensorDict)
     bs = result.batch_size
     assert bs == torch.Size([2 * n_samples])