# Source: AReaL/arealite/api/cli_args.py (69 lines, 2.2 KiB, Python)

from dataclasses import dataclass, field, asdict
from typing import List
@dataclass
class MicroBatchSpec:
    """Specification for splitting micro-batches during training."""

    # NOTE(review): the help text below talks about max_tokens_per_mb being
    # None, but the field is typed ``int`` and defaults to int(1e12), so it is
    # never None unless a caller passes None explicitly — TODO confirm intent.
    n_mbs: int = field(
        default=1,
        metadata={
            "help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
        },
    )
    # The int(1e12) default presumably acts as an effectively unbounded
    # per-micro-batch token cap — verify against the splitting code.
    max_tokens_per_mb: int = field(
        default=int(1e12),
        metadata={
            "help": "Maximum tokens per micro-batch. When set, n_mbs becomes the minimum number of micro-batches",
        },
    )

    @classmethod
    def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
        """Create a new spec copied from ``mb_spec`` with some fields overridden.

        Values are read from ``mb_spec`` by plain attribute access rather than
        ``dataclasses.asdict``, so ``mb_spec`` may be anything exposing the
        fields as attributes (e.g. an OmegaConf config object), which is what
        keeps this OmegaConf-compatible.

        Args:
            mb_spec: Source spec whose ``MicroBatchSpec`` fields are copied.
            **kwargs: Field values that override the copied ones.

        Returns:
            A new instance of ``cls``.

        Raises:
            AttributeError: If ``mb_spec`` lacks one of the ``MicroBatchSpec``
                fields.
            TypeError: If ``kwargs`` contains a name ``cls`` does not accept.
        """
        import dataclasses

        # Derive the field list from the dataclass itself so newly added
        # fields are copied too. Use the fields declared on MicroBatchSpec (not
        # ``cls``) so a subclass's extra fields keep their defaults, exactly as
        # before.
        values = {
            f.name: getattr(mb_spec, f.name)
            for f in dataclasses.fields(MicroBatchSpec)
            if f.init
        }
        values.update(kwargs)
        return cls(**values)
@dataclass
class GenerationHyperparameters:
    """Controls text generation behavior for RL training."""

    n_samples: int = field(
        default=1, metadata={"help": "Number of sequences to generate per prompt."}
    )
    max_new_tokens: int = field(
        default=16384, metadata={"help": "Maximum number of tokens to generate."}
    )
    min_new_tokens: int = field(
        default=0, metadata={"help": "Minimum number of tokens to generate."}
    )
    greedy: bool = field(
        default=False,
        metadata={"help": "Whether to use greedy decoding (max probability)."},
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Nucleus sampling probability threshold (0.0, 1.0]."},
    )
    # int(1e8) is presumably meant as "no top-k limit" — confirm with the
    # generation backend.
    top_k: int = field(
        default=int(1e8),
        metadata={"help": "Number of highest probability tokens to consider."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Sampling temperature. Higher values increase diversity."},
    )
    # default_factory avoids sharing one mutable list across instances.
    stop_token_ids: List[int] = field(
        default_factory=list,
        metadata={"help": "Stop generation when encountering these token ids."},
    )

    def new(self, **kwargs):
        """Return a copy of these hyperparameters with some fields overridden.

        Every init field is deep-copied, so mutable values such as
        ``stop_token_ids`` are not shared between the original and the copy.
        The result has the same concrete type as ``self``, so subclasses
        (including ones with extra fields) round-trip correctly.

        Args:
            **kwargs: Field values that override the copied ones.

        Returns:
            A new instance of ``type(self)``.

        Raises:
            TypeError: If ``kwargs`` contains a name the class does not accept.
        """
        import copy
        import dataclasses

        # Unlike asdict(), this keeps nested dataclass values as dataclasses
        # and skips init=False fields the constructor would reject.
        args = {
            f.name: copy.deepcopy(getattr(self, f.name))
            for f in dataclasses.fields(self)
            if f.init
        }
        args.update(kwargs)
        return type(self)(**args)