mirror of https://github.com/inclusionAI/AReaL
format
This commit is contained in:
parent
28c9479981
commit
b6e19dbf60
|
@ -1,6 +1,7 @@
|
|||
from dataclasses import dataclass, field, asdict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroBatchSpec:
|
||||
"""Specification for splitting micro-batches during training."""
|
||||
|
@ -28,6 +29,7 @@ class MicroBatchSpec:
|
|||
fields.update(kwargs)
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationHyperparameters:
|
||||
"""Controls text generation behavior for RL training."""
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
import abc
|
||||
from typing import Callable, Dict, List, Any, Optional
|
||||
import torch
|
||||
from dataclasses import dataclass, field
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from arealite.api.cli_args import MicroBatchSpec
|
||||
from arealite.api.io_struct import (
|
||||
FinetuneSpec,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
FinetuneSpec,
|
||||
WeightUpdateMeta,
|
||||
SaveLoadMeta,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import abc
|
||||
from typing import Any, Dict, List, Callable
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
|
|
|
@ -7,13 +7,12 @@ import re
|
|||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMRequest:
|
||||
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
@ -44,7 +43,7 @@ class LLMResponse:
|
|||
@property
|
||||
def input_len(self) -> int:
|
||||
return len(self.input_tokens)
|
||||
|
||||
|
||||
@property
|
||||
def output_len(self) -> int:
|
||||
return len(self.output_tokens)
|
||||
|
@ -155,6 +154,7 @@ class AllocationMode:
|
|||
other_alloc.update({"gen": gen_alloc["*"]})
|
||||
return other_alloc
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightUpdateMeta:
|
||||
type: str
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import TYPE_CHECKING, Dict, Any
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from tensordict import TensorDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from tensordict import TensorDict
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from arealite.api.workflow_api import RolloutWorkflow
|
||||
from arealite.api.io_struct import LLMRequest
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from arealite.api.io_struct import LLMRequest
|
||||
from arealite.api.workflow_api import RolloutWorkflow
|
||||
|
||||
|
||||
class RLVRWorkflow(RolloutWorkflow):
|
||||
def __init__(
|
||||
|
|
Loading…
Reference in New Issue