mirror of https://github.com/inclusionAI/AReaL
87 lines
2.3 KiB
Python
87 lines
2.3 KiB
Python
# Copyright 2025 Ant Group Inc.
|
|
|
|
import os
|
|
from typing import *
|
|
|
|
import pytest
|
|
|
|
from realhf.api.quickstart.dataset import PromptAnswerDatasetConfig
|
|
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
|
from realhf.base import cluster, testing
|
|
from realhf.experiments.common.sft_exp import SFTConfig
|
|
from tests.experiments.utils import run_test_exp
|
|
from tests.fixtures import *
|
|
|
|
|
|
@pytest.fixture(params=["llama"])
def model_class(request):
    """Return the model family name for the current fixture parameter.

    The fixture is parametrized over a single value, ``"llama"``, so every
    dependent test is run once with that family.
    """
    family_name = request.param
    return family_name
|
|
|
|
|
|
@pytest.mark.skipif(
    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to 1 so the comparison skips the test instead of raising
    # TypeError while the module is being collected.
    (os.cpu_count() or 1) < 32 or testing.get_free_mem_gb() < 50,
    reason="Testing with larger parallelization degrees requires more CPU cores and memory.",
)
@pytest.mark.parametrize(
    "dp,pp,tp",
    [
        (2, 2, 2),
        (1, 2, 4),
        (2, 4, 1),
        (2, 1, 4),
        (4, 2, 2),
        (2, 4, 2),
        (2, 2, 4),
    ],
)
def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
    """Run the SFT experiment test with larger dp/pp/tp combinations.

    Delegates to ``test_sft`` for each parametrized ``(dp, pp, tp)`` triple.
    The test is skipped when fewer than 32 CPU cores are reported or when
    ``testing.get_free_mem_gb()`` reports less than 50.
    """
    # test_sft requires ``use_v2_worker`` and declares no default for it;
    # omitting it raises TypeError. True is the only value test_sft is
    # parametrized with, so the same setting is used here.
    test_sft(
        tmp_path_factory,
        tokenizer,
        save_path,
        cpu_hf_model,
        dp,
        pp,
        tp,
        use_v2_worker=True,
    )
|
|
|
|
|
|
# NOTE: we can't test v1 and v2 at the same time.
|
|
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
    "dp,pp,tp",
    [
        (1, 1, 1),
        (2, 1, 1),
        (1, 2, 1),
        (1, 1, 2),
    ],
)
def test_sft(
    tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker
):
    """Build an SFT experiment config for one (dp, pp, tp) layout and run it.

    The experiment runs in ``"local"`` mode on a single node with
    ``tp * dp * pp`` GPUs requested, a model initialized from scratch with the
    ``"mock_train"`` backend, and ``save_path / "dataset.json"`` used for both
    the training and the validation set. ``use_v2_worker`` is forwarded to
    ``run_test_exp`` unchanged.
    """
    # Experiment environment setup. This has to happen before anything else.
    sft_root = tmp_path_factory.mktemp("sft")
    cluster.spec.fileroot = str(sft_root)
    # NOTE(review): `constants` has no explicit import in the visible part of
    # this file; presumably it comes from `from tests.fixtures import *` --
    # confirm.
    constants.set_experiment_trial_names(
        testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
    )

    batch_size = 32
    n_gpus = tp * dp * pp

    model_cfg = ModelTrainEvalConfig(
        path=str(save_path),
        init_from_scratch=True,
        backend="mock_train",
    )

    # The same file serves as both the training and the validation set.
    dataset_file = str(save_path / "dataset.json")
    dataset_cfg = PromptAnswerDatasetConfig(
        train_path=dataset_file,
        valid_path=dataset_file,
        max_seqlen=128,
        train_bs_n_seqs=batch_size,
        valid_bs_n_seqs=batch_size,
        fill_to_max_length=False,
    )

    sft_cfg = SFTConfig(
        experiment_name=testing._DEFAULT_EXPR_NAME,
        trial_name=testing._DEFAULT_TRIAL_NAME,
        mode="local",
        allocation_mode=f"m{tp}d{dp}p{pp}",
        n_nodes=1,
        n_gpus_per_node=n_gpus,
        model=model_cfg,
        dataset=dataset_cfg,
    )

    run_test_exp(sft_cfg, use_v2_worker=use_v2_worker)
|