AReaL/tests/experiments/test_sft.py

87 lines
2.3 KiB
Python

# Copyright 2025 Ant Group Inc.
import os
from typing import *
import pytest
from realhf.api.quickstart.dataset import PromptAnswerDatasetConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.base import cluster, testing
from realhf.experiments.common.sft_exp import SFTConfig
from tests.experiments.utils import run_test_exp
from tests.fixtures import *
@pytest.fixture(params=["llama"])
def model_class(request):
return request.param
@pytest.mark.skipif(
os.cpu_count() < 32 or testing.get_free_mem_gb() < 50,
reason="Testing with larger parallelization degrees requires more CPU cores and memory.",
)
@pytest.mark.parametrize(
"dp,pp,tp",
[
(2, 2, 2),
(1, 2, 4),
(2, 4, 1),
(2, 1, 4),
(4, 2, 2),
(2, 4, 2),
(2, 2, 4),
],
)
def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp)
# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
"dp,pp,tp",
[
(1, 1, 1),
(2, 1, 1),
(1, 2, 1),
(1, 1, 2),
],
)
def test_sft(
tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker
):
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("sft")
cluster.spec.fileroot = str(log_root)
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
minbs = 32
exp_cfg = SFTConfig(
experiment_name=testing._DEFAULT_EXPR_NAME,
trial_name=testing._DEFAULT_TRIAL_NAME,
mode="local",
allocation_mode=f"m{tp}d{dp}p{pp}",
n_nodes=1,
n_gpus_per_node=tp * dp * pp,
model=ModelTrainEvalConfig(
path=str(save_path),
init_from_scratch=True,
backend="mock_train",
),
dataset=PromptAnswerDatasetConfig(
train_path=str(save_path / "dataset.json"),
valid_path=str(save_path / "dataset.json"),
max_seqlen=128,
train_bs_n_seqs=minbs,
valid_bs_n_seqs=minbs,
fill_to_max_length=False,
),
)
run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)