Support FSDP engine and SGLang remote engine

bowei.fw 2025-07-09 16:45:29 +08:00
parent 15dfbe837c
commit a6bcab22ba
8 changed files with 37 additions and 24 deletions

View File

@@ -322,23 +322,7 @@ class ExperimentSaveEvalControl:
     total_train_epochs: int = field(
         default=1, metadata={"help": "Total number of epochs to train the model."}
     )
-    # Save control
-    save_freq_epochs: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "Save frequency in epochs. None disables epoch-based saving."
-        },
-    )
-    save_freq_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "Save frequency in steps. None disables step-based saving."},
-    )
-    save_freq_secs: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "Save frequency in seconds. None disables time-based saving."
-        },
-    )
     # Checkpointing control
     ckpt_freq_epochs: Optional[int] = field(
         default=None,
@@ -395,6 +379,30 @@ class ExperimentSaveEvalControl:
     )
+
+
+@dataclass
+class SaverConfig:
+    experiment_name: str
+    trial_name: str
+    fileroot: str
+    # Save control
+    freq_epochs: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Save frequency in epochs. None disables epoch-based saving."
+        },
+    )
+    freq_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Save frequency in steps. None disables step-based saving."},
+    )
+    freq_secs: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Save frequency in seconds. None disables time-based saving."
+        },
+    )
 @dataclass
 class WandBConfig:
     mode: str = "disabled"
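
The three save-frequency knobs move out of ExperimentSaveEvalControl into the new SaverConfig, dropping their save_ prefix on the way. A minimal sketch of how a saver might combine the three optional frequencies (the should_save helper and its arguments are illustrative, not part of this commit):

import time

def should_save(cfg, global_step, steps_per_epoch, last_save_time):
    # Any one frequency being due triggers a save; an unset (None)
    # field simply disables that trigger, matching the help strings.
    step = global_step + 1  # 1-based for the modulo checks
    if cfg.freq_epochs is not None and step % (cfg.freq_epochs * steps_per_epoch) == 0:
        return True
    if cfg.freq_steps is not None and step % cfg.freq_steps == 0:
        return True
    if cfg.freq_secs is not None and time.monotonic() - last_save_time >= cfg.freq_secs:
        return True
    return False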

View File

@@ -60,6 +60,10 @@ class FinetuneSpec:
         # assuming drop_last
         return self.total_train_epochs * (self.dataset_size // self.train_batch_size)
 
+    @property
+    def steps_per_epoch(self):
+        return self.dataset_size // self.train_batch_size
+
 
 class AllocationType(enum.Enum):
     DECOUPLED_vLLM = 1
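
The new property factors the existing total_train_steps computation: with drop_last, each epoch contributes dataset_size // train_batch_size optimizer steps. A quick arithmetic check (the numbers are illustrative):

# With dataset_size=1000 and train_batch_size=64, drop_last discards
# the trailing 40 samples of every epoch.
dataset_size, train_batch_size, total_train_epochs = 1000, 64, 3
steps_per_epoch = dataset_size // train_batch_size        # 15
total_train_steps = total_train_epochs * steps_per_epoch  # 45
assert steps_per_epoch == 15 and total_train_steps == 45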

View File

@@ -4,7 +4,7 @@ from typing import Dict
 
 import torch.distributed as dist
 import wandb
-from tensorboardX import SummaryWrite
+from tensorboardX import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.cli_args import TrainerConfig
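
tensorboardX has no SummaryWrite; the old import would raise ImportError the moment the module loaded, so this is a straight typo fix. For reference, standard usage of the corrected class:

from tensorboardX import SummaryWriter

writer = SummaryWriter("runs/example")    # directory for event files
writer.add_scalar("loss/train", 0.42, 1)  # tag, value, global step
writer.close()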

View File

@@ -1,6 +1,7 @@
 import asyncio
 import threading
 import time
+import traceback
 from concurrent.futures import ThreadPoolExecutor
 from queue import Empty, Full, Queue
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
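
The new traceback import points at better error reporting in the engine's background machinery; exceptions raised inside worker threads are otherwise invisible. A generic sketch of the pattern (the _worker function and logger name are assumptions, not code from this file):

import logging
import traceback

logger = logging.getLogger("sglang_remote")

def _worker(task):
    try:
        task()
    except Exception:
        # Format the full stack trace so it reaches the log instead of
        # being swallowed at the thread boundary.
        logger.error("rollout worker failed:\n%s", traceback.format_exc())
        raise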
@@ -349,7 +350,7 @@ class RemoteSGLangEngine(InferenceEngine):
 
             finish_reason = meta_info["finish_reason"]
             stop_reason = finish_reason["type"]
 
-            payload["text"] += completions
+            payload["text"] += result["text"]
 
         latency = time.perf_counter() - start_time
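
The one-liner fixes what gets appended on each server round-trip: result["text"] is only the newly generated chunk, while completions appears to be the running total, so appending it would duplicate previously received text in the payload. Schematically:

# Appending the per-response chunk vs. the running total.
payload = {"text": "prompt "}
completions = ""
for chunk in ["Hello", ", ", "world"]:   # simulated server responses
    result = {"text": chunk}
    completions += result["text"]
    payload["text"] += result["text"]    # correct: grows by one chunk
    # payload["text"] += completions     # bug: re-appends earlier chunks
assert payload["text"] == "prompt Hello, world"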

View File

@@ -75,7 +75,6 @@ def sglang_server():
         process.terminate()
 
 
-@pytest.mark.skip("")
 @pytest.mark.asyncio
 async def test_remote_sglang_generate(sglang_server):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -99,7 +98,6 @@ async def test_remote_sglang_generate(sglang_server):
     assert isinstance(resp.completions, str)
 
 
-@pytest.mark.skip("")
 @pytest.mark.parametrize("n_samples", [1, 2, 4])
 def test_remote_sglang_rollout(sglang_server, n_samples):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
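
Dropping the two empty-reason skip markers re-enables both tests against the sglang_server fixture; the second now runs once per n_samples value (1, 2, 4). If a skip ever needs to come back, giving it a reason keeps pytest -rs output useful, e.g.:

import pytest

@pytest.mark.skip(reason="needs a local SGLang server on the test host")
def test_remote_sglang_generate_placeholder():
    ...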

View File

@@ -8,7 +8,7 @@ from arealite.utils.data import (
     pad_and_stack_tensors_along_first_dim,
     pad_sequences_to_tensors,
     reorder_list,
-    split_packed_tensor_dict_into_mbs,
+    split_packed_tensor_dict_into_mb_list,
     unpack_sequence,
 )
@@ -45,7 +45,7 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
     packed_data = pack_tensor_dict(mock_padded_data)
     original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
     assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
-    split_result = split_packed_tensor_dict_into_mbs(packed_data, mb_spec)
+    split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
     reordered_lens = [original_lens[i] for i in split_result.forward_indices]
 
     # assert microbatch split result does not violate requirements
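
The rename tracks an API change in arealite.utils.data: the splitter returns a microbatch list together with forward_indices, the permutation the test uses to map per-microbatch lengths back to the originals. A toy statement of the invariant being asserted (values are made up):

# Splitting may reorder sequences but must not create or drop any,
# so lengths gathered via forward_indices are a permutation of the
# original lengths.
original_lens = [5, 3, 8, 2, 6]
forward_indices = [2, 4, 0, 1, 3]  # e.g. longest-first packing order
reordered_lens = [original_lens[i] for i in forward_indices]
assert sorted(reordered_lens) == sorted(original_lens)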

View File

@@ -85,4 +85,6 @@ def main_sft():
 
 
 if __name__ == "__main__":
-    main_sft()
+    args = parse_args()
+    trainer = SFTTrainer(args)
+    trainer.train()
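
The __main__ block now drives training through the trainer object instead of the old main_sft() call. An annotated restatement of the added lines (comments are editorial; parse_args and SFTTrainer are this script's own imports):

if __name__ == "__main__":
    args = parse_args()          # build the experiment config from the CLI
    trainer = SFTTrainer(args)   # trainer added alongside the FSDP engine work
    trainer.train()              # run the supervised fine-tuning loop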