bowei.fw 2025-07-14 11:27:05 +08:00
commit 037adedc70
12 changed files with 117 additions and 22 deletions

View File

@ -1,5 +1,18 @@
# AReaL v1.0.0 Design Doc
---
Update 20250710
SFT example:
```bash
torchrun --nnodes 1 --nproc-per-node 8 examples/arealite/gsm8k_sft.py --config examples/arealite/configs/gsm8k_sft.yaml
```
---
We will provide both single-controller and SPMD user interfaces. The SPMD interface will be delivered with AReaLite; it is the paradigm most users are familiar with, just like launching jobs with `torchrun` or `deepspeed`. However, this paradigm may lack flexibility in global scheduling and control. To unlock the full potential of customized distributed execution, we will also provide a single-controller mode, similar to using Ray, but our scheduler backend will not be restricted to Ray: the code will be able to run with any scheduler in the cluster, such as native SLURM and K8S.
However, we want the user code to stay the same for both modes. The following is a simple usage example:
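(The example itself lies outside this hunk. As a purely hypothetical stand-in, assuming invented names such as `load_config` and `build_dataloader`, mode-agnostic user code could look like the sketch below; the same script would be launched by `torchrun` in SPMD mode or dispatched by the controller's scheduler in single-controller mode.)
```python
# Hypothetical sketch only: these helper names are assumptions, not the documented API.
def main():
    config = load_config()              # assumed config helper
    engine = TrainEngine(config)        # assumed engine wrapper
    engine.initialize()                 # sets up distributed comm + model
    for batch in build_dataloader(config):
        stats = engine.train_batch(batch)
    engine.destroy()

if __name__ == "__main__":
    main()  # identical entry point under both execution modes
```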
@ -737,4 +750,4 @@ dataloader = StatefulDataLoader(
)
for data in dataloader:
assert isinstance(data, list)
```

View File

@ -4,6 +4,9 @@ from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import uvloop
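# uvloop.install() below swaps in uvloop's faster drop-in replacement
# for the default asyncio event loop.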
uvloop.install()
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf
@ -138,6 +141,61 @@ class OptimizerConfig:
)
@dataclass
class OptimizerConfig:
"""Configuration for model optimization during training.
Note:
Set type to "empty" for models that won't be trained.
"""
type: str = field(
default="adam",
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
)
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
min_lr_ratio: float = field(
default=0.0,
metadata={
"help": "Minimum learning rate ratio after annealing",
},
)
lr_scheduler_type: str = field(
default="constant",
metadata={
"help": "Learning rate scheduler type",
"choices": ["linear", "cosine", "constant"],
},
)
warmup_steps_proportion: float = field(
default=0.001,
metadata={
"help": "Proportion of training steps for warmup",
},
)
offload: bool = field(
default=False, metadata={"help": "Enable optimizer state offloading"}
)
initial_loss_scale: float = field(
default=2**32, metadata={"help": "Initial loss scaling factor"}
)
min_loss_scale: float = field(
default=1.0, metadata={"help": "Minimum loss scaling factor"}
)
loss_scale_window: float = field(
default=5, metadata={"help": "Window size for loss scaling adjustment"}
)
hysteresis: int = field(
default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
)
gradient_clipping: float = field(
default=1.0, metadata={"help": "Gradient clipping threshold"}
)
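As a brief usage sketch of the `"empty"` escape hatch mentioned in the docstring (values below are illustrative, not recommendations):
```python
# Illustrative only: a trainable actor vs. a frozen reference model.
actor_opt = OptimizerConfig(lr=1e-5, lr_scheduler_type="cosine")
ref_opt = OptimizerConfig(type="empty")  # per the docstring: for models that won't be trained
```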
@dataclass
class FSDPWrapPolicy:
transformer_layer_cls_to_wrap: Optional[List[str]] = field(
@ -181,7 +239,7 @@ class TrainEngineConfig:
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
# Training Backend Configuration
disable_dropout: bool = False
disable_dropout: bool = field(default=False)
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)

View File

@ -49,7 +49,7 @@ from arealite.utils.fsdp import (
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version, constants
from realhf.base import constants, logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
@ -91,7 +91,9 @@ class FSDPEngine(TrainEngine):
"""Initialize distributed communication and model."""
if not dist.is_initialized():
# TODO: Handle the case where WORLD_SIZE and RANK are not set by the launcher
dist.init_process_group(backend="nccl", timeout=constants.NCCL_DEFAULT_TIMEOUT)
dist.init_process_group(
backend="nccl", timeout=constants.NCCL_DEFAULT_TIMEOUT
)
# TODO: Handle the case where LOCAL_RANK is not set by the launcher
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
@ -107,6 +109,9 @@ class FSDPEngine(TrainEngine):
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLMs cannot directly load a state dict into this
# randomly initialized model, so in that case we call
# from_pretrained instead of loading weights into the random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
@ -344,6 +349,10 @@ class FSDPEngine(TrainEngine):
# NOTE: We unsqueeze here because huggingface transformer models require
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
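To make the shape contract in the NOTE above concrete, a minimal sketch (token values made up) of a packed micro-batch after unsqueezing:
```python
import torch

# Two sequences of lengths 3 and 5, packed into a single row.
seq_a = torch.tensor([11, 12, 13])
seq_b = torch.tensor([21, 22, 23, 24, 25])
input_ids = torch.cat([seq_a, seq_b]).unsqueeze(0)       # [1, 8] == [1, total_seqlen]
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # sequence boundaries
max_seqlen = int(cu_seqlens.diff().max())                # plain int, as in the FIXME cast above
```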
def train_batch(
@ -372,7 +381,7 @@ class FSDPEngine(TrainEngine):
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input, use_cache=False)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
@ -426,7 +435,7 @@ class FSDPEngine(TrainEngine):
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input, use_cache=False)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
@ -458,7 +467,7 @@ class FSDPEngine(TrainEngine):
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input, use_cache=False)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits

Binary file not shown.

Before: 163 KiB | After: 2.7 KiB

View File

@ -6,15 +6,16 @@ GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}
echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"
# If there is already an image named areal-env, skip.
if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q 'areal-env:latest'; then
RUN_ID="areal-$GIT_COMMIT_SHA"
cd "/tmp/$RUN_ID"
# If there is already an image for the current environment, skip the build.
ENV_SHA=$(sha256sum pyproject.toml | awk '{print $1}')
if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q "areal-env:$ENV_SHA"; then
echo "Image areal-env already exists, skipping build."
exit 0
fi
RUN_ID="areal-$GIT_COMMIT_SHA"
cd "/tmp/$RUN_ID"
if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
docker rm -f $RUN_ID
fi
@ -35,5 +36,5 @@ docker run \
mv ./sglang /sglang
" || { docker rm -f $RUN_ID; exit 1; }
docker commit $RUN_ID areal-env:latest
docker commit $RUN_ID "areal-env:$ENV_SHA"
docker rm -f $RUN_ID

View File

@ -13,13 +13,14 @@ if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
docker rm -f $RUN_ID
fi
ENV_SHA=$(sha256sum pyproject.toml | awk '{print $1}')
docker run \
--name $RUN_ID \
--gpus all \
--shm-size=8g \
-v $(pwd):/workspace \
-w /workspace \
areal-env:latest \
"areal-env:$ENV_SHA" \
bash -c "
mv /sglang ./sglang
HF_ENDPOINT=https://hf-mirror.com python -m pytest -s arealite/

View File

@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
return result["result"], result["info"]
def code_verify(id2info, generateds, query_ids, debug=False):
def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]
@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
run_results = []
num_process = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
if max_workers is None:
max_workers = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
run_results = executor.map(call_verify, *zip(*infer_args))
for run_result in run_results:
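With the pool size now a parameter, a caller can bound verification parallelism explicitly. A hypothetical invocation (the structure of `id2info` values is assumed, not shown here):
```python
# Hypothetical usage: cap the verifier pool at 2 processes instead of
# the default max(1, os.cpu_count() // 8).
results = code_verify(
    id2info,     # query_id -> problem spec (shape assumed)
    generateds,  # one generated solution per query id
    query_ids,
    max_workers=2,
)
```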

View File

@ -53,9 +53,9 @@ dependencies = [
"hydra-core==1.4.0.dev1",
"packaging",
"tabulate",
"gymnasium>=1.1.1",
"torchdata",
"autoflake",
"gymnasium",
"tensordict",
# Monitoring and logging

View File

@ -8,6 +8,7 @@ import os
import random
import time
from contextlib import contextmanager
from functools import lru_cache
# NOTE: We don't use wildcard imports here because the type
# `Sequence` has a very similar name to `SequenceSample`.
@ -47,6 +48,7 @@ logger = logging.getLogger("api.data")
RL_TASKS = ["math", "code", "rlhf", "stem"]
@lru_cache(maxsize=8)
def load_hf_tokenizer(
model_name_or_path: str,
fast_tokenizer=True,
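Because `lru_cache` keys on the call arguments, repeated loads of the same tokenizer now return the cached object instead of re-reading it. A small sketch (model path hypothetical):
```python
tok_a = load_hf_tokenizer("Qwen/Qwen2-1.5B")  # first call: loads from disk/hub
tok_b = load_hf_tokenizer("Qwen/Qwen2-1.5B")  # second call: cache hit
assert tok_a is tok_b                         # the very same object
```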

View File

@ -7,7 +7,7 @@ from typing import List, Union
import regex
from latex2sympy2 import latex2sympy
from pebble import ProcessPool
from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
string = string.replace("%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
else: # use the last number
if use_last_number:
pattern = r"-?\d*\.?\d+"
pattern = "-?\d*\.?\d+"
pred = re.findall(pattern, pred_str.replace(",", ""))
if len(pred) >= 1:
pred = pred[-1]
@ -836,6 +837,12 @@ def parse_lines_in_parallel(
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
labels.append(label)
return labels
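For context on what the new handlers guard against: with pebble, iterating a map result raises `TimeoutError` when a single task exceeds its timeout and `ProcessExpired` when a worker process dies (e.g. is OOM-killed). A minimal sketch of that pattern, with a hypothetical worker `verify_one` and inputs `items`:
```python
from pebble import ProcessExpired, ProcessPool

with ProcessPool(max_workers=4) as pool:
    future = pool.map(verify_one, items, timeout=15)  # verify_one/items assumed
    results = future.result()
    while True:
        try:
            x = next(results)
        except StopIteration:
            break                          # all tasks consumed
        except TimeoutError:
            x = (0, "timeout", "timeout")  # one task timed out
        except ProcessExpired:
            x = (0, "error", "error")      # worker terminated abnormally
```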

View File

@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
elif cur_task == "code":
answers = [extract_code(x) for x in answers]
@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
else:
raise NotImplementedError()

View File

@ -69,8 +69,8 @@ word2number
Pebble
timeout-decorator
prettytable
gymnasium>=1.1.1
swanlab[dashboard]
torchdata
autoflake
gymnasium
tensordict