Mirror of https://github.com/inclusionAI/AReaL
Commit 037adedc70

@@ -1,5 +1,18 @@
# AReaL v1.0.0 Design Doc

---

Update 20250710

SFT example:

```bash
torchrun --nnodes 1 --nproc-per-node 8 examples/arealite/gsm8k_sft.py --config examples/arealite/configs/gsm8k_sft.yaml
```

---

We will provide both single-controller and SPMD user interfaces. The SPMD interface will be delivered with AReaLite and follows the paradigm most users are already familiar with, akin to launching with `torchrun` or `deepspeed`. However, this paradigm offers limited flexibility for global scheduling and control. To unlock the full potential of customized distributed execution, we will also provide a single-controller mode, similar to using Ray --- but our scheduler backend will not be restricted to Ray. Our code will be able to run with any scheduler in the cluster, such as native SLURM and K8S.
However, we want the user code to stay the same for both modes. The following is a simple usage example:
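
(The example itself is truncated in this diff view. As a stand-in, here is a minimal, self-contained sketch of the idea; every name below is illustrative, not AReaLite's confirmed API.)

```python
# Sketch only: user code is written once against an engine interface, and
# the launcher decides where it runs, either as one SPMD rank under
# torchrun or as a driver process in single-controller mode.
from typing import Iterable, Protocol


class TrainEngine(Protocol):
    def train_batch(self, batch: dict) -> dict: ...


def train_loop(engine: TrainEngine, dataloader: Iterable[dict]) -> None:
    # Identical user code in both modes; only the engine binding differs.
    for batch in dataloader:
        stats = engine.train_batch(batch)
        print(stats)
```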
@@ -737,4 +750,4 @@ dataloader = StatefulDataLoader(
)
for data in dataloader:
    assert isinstance(data, list)
```

@@ -4,6 +4,9 @@ from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

+import uvloop
+
+uvloop.install()
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf

@@ -138,6 +141,61 @@ class OptimizerConfig:
    )


+@dataclass
+class OptimizerConfig:
+    """Configuration for model optimization during training.
+
+    Note:
+        Set type to "empty" for models that won't be trained.
+    """
+
+    type: str = field(
+        default="adam",
+        metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
+    )
+    lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
+    weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
+    beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
+    beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
+    eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
+    min_lr_ratio: float = field(
+        default=0.0,
+        metadata={
+            "help": "Minimum learning rate ratio after annealing",
+        },
+    )
+    lr_scheduler_type: str = field(
+        default="constant",
+        metadata={
+            "help": "Learning rate scheduler type",
+            "choices": ["linear", "cosine", "constant"],
+        },
+    )
+    warmup_steps_proportion: float = field(
+        default=0.001,
+        metadata={
+            "help": "Proportion of training steps for warmup",
+        },
+    )
+    offload: bool = field(
+        default=False, metadata={"help": "Enable optimizer state offloading"}
+    )
+    initial_loss_scale: float = field(
+        default=2**32, metadata={"help": "Initial loss scaling factor"}
+    )
+    min_loss_scale: float = field(
+        default=1.0, metadata={"help": "Minimum loss scaling factor"}
+    )
+    loss_scale_window: float = field(
+        default=5, metadata={"help": "Window size for loss scaling adjustment"}
+    )
+    hysteresis: int = field(
+        default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
+    )
+    gradient_clipping: float = field(
+        default=1.0, metadata={"help": "Gradient clipping threshold"}
+    )


@dataclass
class FSDPWrapPolicy:
    transformer_layer_cls_to_wrap: Optional[List[str]] = field(

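Aside: a sketch of how a config like this typically feeds an optimizer constructor. `build_optimizer` is a hypothetical helper, not part of this commit; only the `OptimizerConfig` fields come from the diff above.

```python
# Hypothetical consumer of OptimizerConfig (not in this commit): map the
# declarative fields onto a torch optimizer, honoring type == "empty".
import torch


def build_optimizer(cfg: "OptimizerConfig", model: torch.nn.Module):
    if cfg.type == "empty":
        return None  # per the docstring: the model will not be trained
    assert cfg.type == "adam", f"unsupported optimizer type: {cfg.type}"
    return torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=(cfg.beta1, cfg.beta2),
        eps=cfg.eps,
    )
```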
@@ -181,7 +239,7 @@ class TrainEngineConfig:
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)

    # Training Backend Configuration
-    disable_dropout: bool = False
+    disable_dropout: bool = field(default=False)
    gradient_checkpointing: bool = field(
        default=True, metadata={"help": "Enable gradient checkpointing"}
    )

@@ -49,7 +49,7 @@ from arealite.utils.fsdp import (
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import logging, name_resolve, names, pkg_version, constants
+from realhf.base import constants, logging, name_resolve, names, pkg_version

logger = logging.getLogger("FSDPEngine")

@@ -91,7 +91,9 @@ class FSDPEngine(TrainEngine):
        """Initialize distributed communication and model."""
        if not dist.is_initialized():
            # TODO: Handle the condition when WORLD_SIZE and RANK are not set in launcher
-            dist.init_process_group(backend="nccl", timeout=constants.NCCL_DEFAULT_TIMEOUT)
+            dist.init_process_group(
+                backend="nccl", timeout=constants.NCCL_DEFAULT_TIMEOUT
+            )

        # TODO: Handle the condition when LOCAL_RANK is not set in launcher
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

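Aside: regarding the two TODOs above, a common guard (a sketch, not this repo's code) is to fall back to single-process defaults when the launcher did not export the usual torchrun environment variables:

```python
# Sketch of the fallback hinted at by the TODOs: supply single-process
# defaults for the env:// rendezvous when no launcher has set them.
import os

import torch
import torch.distributed as dist

if not dist.is_initialized():
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl")  # nccl assumes a CUDA device
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
```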
@@ -107,6 +109,9 @@ class FSDPEngine(TrainEngine):
        with torch.device("cuda"):
            if self.config.init_from_scratch:
                # initialize scratch model from config
+                # NOTE: a VLM cannot directly load a state dict into such a
+                # randomly initialized model; in that case we call
+                # from_pretrained instead of loading weights into the random model.
                model = AutoModelForCausalLM.from_config(
                    self.model_config,
                    torch_dtype=dtype,

@@ -344,6 +349,10 @@ class FSDPEngine(TrainEngine):
        # NOTE: We unsqueeze here because huggingface transformer models require
        # packed input to be of shape [1, total_seqlen].
        mb_list = unsqueeze_mb_list(mb_list)
+        # FIXME: the resulting max_seqlen is a tensor rather than an integer
+        for mb in mb_list.mbs:
+            mb["max_seqlen"] = int(mb["max_seqlen"])
+            mb["use_cache"] = False
        return mb_list

    def train_batch(

@@ -372,7 +381,7 @@ class FSDPEngine(TrainEngine):
            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
        ):
            self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
-            outputs = self.model(**padded_mb_input, use_cache=False)
+            outputs = self.model(**padded_mb_input)

            logits = outputs.logits.squeeze(0)
            logits = logits[:-pad_length] if pad_length > 0 else logits

@@ -426,7 +435,7 @@ class FSDPEngine(TrainEngine):
        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        ):
-            outputs = self.model(**padded_mb_input, use_cache=False)
+            outputs = self.model(**padded_mb_input)
            logits = outputs.logits.squeeze(0)
            logits = logits[:-pad_length] if pad_length > 0 else logits
            loss = loss_fn(logits, mb_input)

@@ -458,7 +467,7 @@ class FSDPEngine(TrainEngine):
        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        ):
-            outputs = self.model(**padded_mb_input, use_cache=False)
+            outputs = self.model(**padded_mb_input)
            logits = outputs.logits.squeeze(0)
            logits = logits[:-pad_length] if pad_length > 0 else logits

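Aside: the pattern repeated in the three hunks above, squeeze the packed [1, total_seqlen] batch and then trim the pad tail, can be checked in isolation. A self-contained sketch with made-up shapes:

```python
# Sketch of the padded-forward pattern above: HF models receive packed
# input of shape [1, total_seqlen + pad]; after the forward pass the batch
# dim is squeezed away and the padded tail of the logits is dropped.
import torch

total_seqlen, pad_length, vocab = 10, 3, 8

# Stand-in for `outputs.logits` produced from the unsqueezed, padded input.
logits = torch.randn(1, total_seqlen + pad_length, vocab)
logits = logits.squeeze(0)  # [total_seqlen + pad_length, vocab]
logits = logits[:-pad_length] if pad_length > 0 else logits
assert logits.shape == (total_seqlen, vocab)
```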
Binary file not shown (image changed: 163 KiB before, 2.7 KiB after).

@@ -6,15 +6,16 @@ GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"}

echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA"

-# If there is already an image named areal-env, skip.
-if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q 'areal-env:latest'; then
+RUN_ID="areal-$GIT_COMMIT_SHA"
+cd "/tmp/$RUN_ID"
+
+# If there is already an image for the current environment, skip the build.
+ENV_SHA=$(sha256sum pyproject.toml | awk '{print $1}')
+if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q "areal-env:$ENV_SHA"; then
    echo "Image areal-env already exists, skipping build."
    exit 0
fi

-RUN_ID="areal-$GIT_COMMIT_SHA"
-cd "/tmp/$RUN_ID"
-
if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
    docker rm -f $RUN_ID
fi

@@ -35,5 +36,5 @@ docker run \
    mv ./sglang /sglang
" || { docker rm -f $RUN_ID; exit 1; }

-docker commit $RUN_ID areal-env:latest
+docker commit $RUN_ID "areal-env:$ENV_SHA"
docker rm -f $RUN_ID

@@ -13,13 +13,14 @@ if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then
    docker rm -f $RUN_ID
fi

+ENV_SHA=$(sha256sum pyproject.toml | awk '{print $1}')
docker run \
    --name $RUN_ID \
    --gpus all \
    --shm-size=8g \
    -v $(pwd):/workspace \
    -w /workspace \
-    areal-env:latest \
+    "areal-env:$ENV_SHA" \
    bash -c "
        mv /sglang ./sglang
        HF_ENDPOINT=https://hf-mirror.com python -m pytest -s arealite/

@@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
    return result["result"], result["info"]


-def code_verify(id2info, generateds, query_ids, debug=False):
+def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
    assert len(generateds) == len(query_ids)
    problems = [id2info[qid] for qid in query_ids]

@@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
        infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))

    run_results = []
-    num_process = max(1, os.cpu_count() // 8)
-    with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
+    if max_workers is None:
+        max_workers = max(1, os.cpu_count() // 8)
+
+    with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
        run_results = executor.map(call_verify, *zip(*infer_args))

    for run_result in run_results:

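Aside: the change above turns the hard-coded pool size into a `max_workers` parameter while keeping `max(1, os.cpu_count() // 8)` as the default. The pattern in isolation (illustrative only; `check` stands in for `call_verify`):

```python
# Illustrative sketch of the configurable process-pool fan-out above.
import concurrent.futures
import os


def check(x: int) -> bool:
    return x % 2 == 0


def run_all(xs, max_workers=None):
    if max_workers is None:
        # Same default the diff keeps: an eighth of the CPUs, at least 1.
        max_workers = max(1, os.cpu_count() // 8)
    with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
        return list(executor.map(check, xs))


if __name__ == "__main__":
    print(run_all(range(6), max_workers=1))  # [True, False, True, ...]
```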
@@ -53,9 +53,9 @@ dependencies = [
    "hydra-core==1.4.0.dev1",
    "packaging",
    "tabulate",
-    "gymnasium>=1.1.1",
-    "torchdata",
-    "autoflake",
+    "gymnasium",
+    "tensordict",

    # Monitoring and logging

@@ -8,6 +8,7 @@ import os
import random
import time
from contextlib import contextmanager
+from functools import lru_cache

# NOTE: We don't use wildcard imports here because the type
# `Sequence` has a very similar name to `SequenceSample`.

@@ -47,6 +48,7 @@ logger = logging.getLogger("api.data")
RL_TASKS = ["math", "code", "rlhf", "stem"]


+@lru_cache(maxsize=8)
def load_hf_tokenizer(
    model_name_or_path: str,
    fast_tokenizer=True,

@@ -7,7 +7,7 @@ from typing import List, Union

import regex
from latex2sympy2 import latex2sympy
-from pebble import ProcessPool
+from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr

@@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):

    # remove percentage
    string = string.replace("\\%", "")
+    string = string.replace("\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string

@@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    else:  # use the last number
        if use_last_number:
-            pattern = r"-?\d*\.?\d+"
+            pattern = "-?\d*\.?\d+"
            pred = re.findall(pattern, pred_str.replace(",", ""))
            if len(pred) >= 1:
                pred = pred[-1]

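Aside: the changed `pattern` line above only toggles the raw-string prefix. For this particular pattern the two spellings are the same string, since `\d` and `\.` are not recognized Python escape sequences (newer Pythons do warn on the non-raw spelling, which is why the raw form is generally preferred). A quick check:

```python
# Quick check that both spellings above denote the same regex.
import re

raw = r"-?\d*\.?\d+"
plain = "-?\\d*\\.?\\d+"  # explicitly escaped spelling of the same text
assert raw == plain
print(re.findall(raw, "so the answer is -3.14, or 42"))  # ['-3.14', '42']
```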
@@ -836,6 +837,12 @@ def parse_lines_in_parallel(
            # print("[debug: timeout]")
            logger.warning("Timeout occurred while justifying the math answer.")
            x = (0, "timeout", "timeout")
+        except ProcessExpired as e:
+            logger.warning(f"Process terminated abnormally: {e}")
+            x = (0, "error", "error")
        except Exception as e:
            logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
            x = (0, "error", "error")
        label = label or x[0]
        labels.append(label)
    return labels

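Aside: `ProcessExpired` is what Pebble raises when a worker process dies abnormally (as opposed to a per-item timeout, which surfaces as `TimeoutError`), so the new branch above completes the standard Pebble consumption loop. A sketch of that loop, not the repo's code:

```python
# Sketch of the Pebble result-consumption pattern extended above:
# per-item timeouts raise TimeoutError, and a worker that dies
# abnormally raises ProcessExpired.
from concurrent.futures import TimeoutError

from pebble import ProcessExpired, ProcessPool


def judge(x):
    return x * x  # stand-in for the real verification function


def run(items, timeout=5):
    results = []
    with ProcessPool() as pool:
        future = pool.map(judge, items, timeout=timeout)
        iterator = future.result()
        while True:
            try:
                results.append(next(iterator))
            except StopIteration:
                break
            except TimeoutError:
                results.append("timeout")
            except ProcessExpired as exc:
                results.append(f"process died: {exc}")
    return results
```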
@@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
                self.id2info,
                answers,
                [qid for _ in range(group_size)],
+                max_workers=1,
            )
        elif cur_task == "code":
            answers = [extract_code(x) for x in answers]

@@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
                self.id2info,
                answers,
                [qid for _ in range(group_size)],
+                max_workers=1,
            )
        else:
            raise NotImplementedError()

@@ -69,8 +69,8 @@ word2number
Pebble
timeout-decorator
prettytable
-gymnasium>=1.1.1
swanlab[dashboard]
-torchdata
-autoflake
+gymnasium
+tensordict