This commit is contained in:
garrett4wade 2025-07-07 14:02:59 +08:00
parent 28c9479981
commit b6e19dbf60
6 changed files with 22 additions and 15 deletions

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass, field, asdict
from dataclasses import asdict, dataclass, field
from typing import List
@dataclass
class MicroBatchSpec:
"""Specification for splitting micro-batches during training."""
@ -28,6 +29,7 @@ class MicroBatchSpec:
fields.update(kwargs)
return cls(**fields)
@dataclass
class GenerationHyperparameters:
"""Controls text generation behavior for RL training."""

View File

@ -1,15 +1,17 @@
import abc
from typing import Callable, Dict, List, Any, Optional
import torch
from dataclasses import dataclass, field
from concurrent.futures import Future
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
import torch
from arealite.api.cli_args import MicroBatchSpec
from arealite.api.io_struct import (
FinetuneSpec,
LLMRequest,
LLMResponse,
FinetuneSpec,
WeightUpdateMeta,
SaveLoadMeta,
WeightUpdateMeta,
)

View File

@ -1,5 +1,5 @@
import abc
from typing import Any, Dict, List, Callable
from typing import Any, Callable, Dict, List
class Environment(abc.ABC):

View File

@ -7,13 +7,12 @@ import re
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
@dataclass
class LLMRequest:
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
@ -44,7 +43,7 @@ class LLMResponse:
@property
def input_len(self) -> int:
return len(self.input_tokens)
@property
def output_len(self) -> int:
return len(self.output_tokens)
@ -155,6 +154,7 @@ class AllocationMode:
other_alloc.update({"gen": gen_alloc["*"]})
return other_alloc
@dataclass
class WeightUpdateMeta:
type: str

View File

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Dict, Any
from typing import TYPE_CHECKING, Any, Dict
from tensordict import TensorDict
if TYPE_CHECKING:

View File

@ -1,11 +1,13 @@
from tensordict import TensorDict
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.workflow_api import RolloutWorkflow
from arealite.api.io_struct import LLMRequest
import uuid
import torch
from tensordict import TensorDict
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
class RLVRWorkflow(RolloutWorkflow):
def __init__(