AReaL/arealite/api/rollout_api.py

# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import abc
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional, SupportsFloat

from gymnasium import Env
from gymnasium.core import ActType, ObsType
from gymnasium.utils import seeding

from arealite.api.cli_args import (
    GenerationHyperparameters,
    RolloutCollectorConfig,
    TrainingArgs,
)
from arealite.api.io_struct import AgentInferInput, AgentInferOutput, Trajectory
from arealite.api.llm_client_api import LLMClient


class Agent(abc.ABC):
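    """An agent that maps observations to actions, plus the data needed for RL training."""
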
    def __init__(self, args: TrainingArgs):
        self.args = args

    async def aact(self, inp: AgentInferInput) -> AgentInferOutput:
        """Async version of act. Given an observation, return an action and data used for RL training."""
        raise NotImplementedError()

    async def areset(self) -> None:
        """Async version of reset. Resets the agent's memory."""
        raise NotImplementedError()


# Re-export the gymnasium environment class
class Environment(abc.ABC, Env):
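    """Gymnasium-style environment interface bound to the experiment's TrainingArgs."""
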
    def __init__(self, args: TrainingArgs):
        self.args = args

    @abc.abstractmethod
    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
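        """Run one environment step.

        Returns the gymnasium-style tuple
        ``(observation, reward, terminated, truncated, info)``.
        """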
        raise NotImplementedError()

    @abc.abstractmethod
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
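        """Reset the environment and return the initial ``(observation, info)`` pair.

        Subclasses should call ``super().reset(seed=seed)`` so the RNG below is
        initialized consistently, mirroring ``gymnasium.Env.reset``.
        """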
# Initialize the RNG if the seed is manually passed
if seed is not None:
self._np_random, self._np_random_seed = seeding.np_random(seed)


class RolloutCollector(abc.ABC):
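    """Collects trajectories for RL training.

    Agentic collectors drive an (agent, env) pair; RLVR collectors score plain
    LLM generations with a reward function.
    """
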
def __init__(
self,
args: TrainingArgs,
config: RolloutCollectorConfig,
agent: Agent | None = None,
env: Environment | None = None,
reward_func: Callable | None = None,
):
self.args = args
self.config = config
# Used in agentic scenarios
self.agent = agent
self.env = env
# Used in RLVR
self.reward_func = reward_func

    async def arun_episode(
self,
llm_client: LLMClient,
gconfig: GenerationHyperparameters,
env_option: Optional[Any] = None,
seed: Optional[int] = None,
) -> Trajectory:
"""Async version of run_episode. Run a single episode and return the trajectory."""
raise NotImplementedError()


@dataclass
class RolloutCollectorFactory:
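    """Builds the RolloutCollector implementation selected by a RolloutCollectorConfig.

    Usage sketch (assumes ``args: TrainingArgs`` and a parsed
    ``collector_config: RolloutCollectorConfig``)::

        factory = RolloutCollectorFactory(args)
        collector = factory.make_collector(collector_config)
    """
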
args: TrainingArgs

    def make_collector(self, config: RolloutCollectorConfig) -> RolloutCollector:
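        """Instantiate the collector implementation selected by ``config.type``."""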
if config.type == "rlvr":
from arealite.impl.rlvr.rlvr_collector import RlvrCollector
rlvr_config = config.rlvr
assert rlvr_config is not None
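            # Dispatch on the configured reward type to build the reward function.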
if rlvr_config.reward_type == "areal-math":
from arealite.impl.rlvr.rewards.areal_math import math_reward
reward_fn = functools.partial(
math_reward, dataset_path=rlvr_config.solution_path
)
elif rlvr_config.reward_type == "areal-code":
from arealite.impl.rlvr.rewards.areal_code import code_reward
reward_fn = functools.partial(
code_reward, dataset_path=rlvr_config.solution_path
)
elif rlvr_config.reward_type == "gsm8k":
from arealite.impl.rlvr.rewards.gsm8k import (
gsm8k_reward_fn as reward_fn,
)
elif rlvr_config.reward_type == "clevr_count_70k":
from arealite.impl.rlvr.rewards.clevr_count_70k import (
clevr_count_70k_reward_fn as reward_fn,
)
else:
raise NotImplementedError(
f"Unknown reward type: {rlvr_config.reward_type}"
)
return RlvrCollector(
self.args,
config=config,
reward_fn=reward_fn,
)
if config.type == "math_code_single_step":
from arealite.impl.agentic.math_code_single_step import (
MathCodeAgent,
MathCodeSingleStepCollector,
MathCodeSingleStepEnv,
)
agent = MathCodeAgent(self.args)
env = MathCodeSingleStepEnv(
self.args,
solution_path=config.math_code_single_step.solution_path,
)
return MathCodeSingleStepCollector(
self.args,
config=config,
agent=agent,
env=env,
)
raise NotImplementedError(f"Unknown agent type: {config.type}")