AReaL/arealite/tests/test_sft.py

109 lines
2.9 KiB
Python
Executable File

"""Test script for the SFT trainer."""
import os
from typing import Dict
import torch
from datasets import load_dataset
from arealite.api.cli_args import (
DatasetConfig,
DatasetPreprocessor,
EngineBackendConfig,
EngineConfig,
OptimizerConfig,
SFTTrainerConfig,
TrainerConfig,
TrainingArgs,
)
from arealite.api.dataset_api import DatasetFactory
from arealite.api.trainer_api import TrainerFactory
def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
    """Mock loss function for testing.

    Args:
        logits: Tensor to reduce.
        input_data: Accepted for signature compatibility; not read here.

    Returns:
        The mean over all elements of ``logits`` (``torch.mean``).
    """
    return torch.mean(logits)
def mock_loss_weight_fn(logits: torch.Tensor, input_data: Dict) -> float:
    """Mock loss weight function for testing.

    Args:
        logits: Accepted for signature compatibility; not read here.
        input_data: Mapping that must contain an ``"attention_mask"`` entry
            supporting ``.sum()``.

    Returns:
        The sum of ``input_data["attention_mask"]`` converted to ``float``.

    Raises:
        KeyError: If ``"attention_mask"`` is missing from ``input_data``.
    """
    return float(input_data["attention_mask"].sum())
def test_sft():
    """Test SFTTrainer.

    Builds train/validation dataset configs for ``openai/gsm8k``, an engine
    config with the ``"hf"`` backend, and a ``TrainingArgs`` for a single
    node with a single GPU. It then creates the datasets through
    ``DatasetFactory``, keeps indices 0-99 of each, creates a trainer through
    ``TrainerFactory`` and calls ``trainer.train()``.

    Side effects:
        Overwrites ``WORLD_SIZE``, ``RANK``, ``LOCAL_RANK``, ``MASTER_ADDR``
        and ``MASTER_PORT`` in ``os.environ``.
    """
    # environment variables for torch distributed
    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "7777"

    # Dataset configs: same source and preprocessor, different split; only
    # the training split is shuffled.
    train_dataset_config = DatasetConfig(
        path="openai/gsm8k",
        preprocessor=DatasetPreprocessor("gsm8k_sft"),
        name="main",
        split="train",
        batch_size=8,
        shuffle=True,
        pin_memory=True,
    )
    valid_dataset_config = DatasetConfig(
        path="openai/gsm8k",
        preprocessor=DatasetPreprocessor("gsm8k_sft"),
        name="main",
        split="test",
        batch_size=8,
        shuffle=False,
        pin_memory=True,
    )

    # NOTE(review): the model path is a hard-coded absolute path; the test
    # presumably only runs on machines where it exists - confirm.
    engine_config = EngineConfig(
        path="/storage/openpsi/models/Qwen__Qwen3-1.7B/",
        gradient_checkpointing=False,
        optimizer=OptimizerConfig(),
        backend=EngineBackendConfig(type="hf"),
    )
    sft_config = SFTTrainerConfig(
        model=engine_config,
    )
    train_config = TrainerConfig(
        type="sft",
        sft=sft_config,
    )
    args = TrainingArgs(
        experiment_name="test-sft",
        trial_name="test",
        mode="local",
        n_nodes=1,
        n_gpus_per_node=1,
        train_dataset=train_dataset_config,
        valid_dataset=valid_dataset_config,
        trainer=train_config,
    )

    # SFT is run here without a rollout controller.
    rollout_controller = None

    # Number of leading examples kept from each split.
    n_samples = 100

    dataset_factory = DatasetFactory(args)
    # NOTE(review): the trailing (0, 1) is presumably (rank, world_size),
    # matching the single-process environment set above - confirm against
    # DatasetFactory.make_dataset.
    train_dataset = dataset_factory.make_dataset(args.train_dataset, 0, 1)
    train_dataset = train_dataset.select(range(n_samples))

    valid_dataset = None
    if args.valid_dataset is not None:
        valid_dataset = dataset_factory.make_dataset(args.valid_dataset, 0, 1)
        valid_dataset = valid_dataset.select(range(n_samples))

    if args.trainer is not None:
        trainer_factory = TrainerFactory(args)
        trainer = trainer_factory.make_trainer(
            args.trainer,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            rollout_controller=rollout_controller,
        )
        trainer.train()

    print("All tests passed!")
# Run the test only when this file is executed as a script, so that merely
# importing the module (e.g. during pytest collection) does not start a
# training run.
if __name__ == "__main__":
    test_sft()