mirror of https://github.com/inclusionAI/AReaL
Support FSDP engine and SGLang remote engine
parent 15dfbe837c
commit a6bcab22ba
@@ -322,23 +322,7 @@ class ExperimentSaveEvalControl:
     total_train_epochs: int = field(
         default=1, metadata={"help": "Total number of epochs to train the model."}
     )
-    # Save control
-    save_freq_epochs: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "Save frequency in epochs. None disables epoch-based saving."
-        },
-    )
-    save_freq_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "Save frequency in steps. None disables step-based saving."},
-    )
-    save_freq_secs: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "Save frequency in seconds. None disables time-based saving."
-        },
-    )
+
     # Checkpointing control
     ckpt_freq_epochs: Optional[int] = field(
         default=None,
@@ -395,6 +379,30 @@ class ExperimentSaveEvalControl:
     )
 
 
+@dataclass
+class SaverConfig:
+    experiment_name: str
+    trial_name: str
+    fileroot: str
+    # Save control
+    freq_epochs: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Save frequency in epochs. None disables epoch-based saving."
+        },
+    )
+    freq_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Save frequency in steps. None disables step-based saving."},
+    )
+    freq_secs: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Save frequency in seconds. None disables time-based saving."
+        },
+    )
+
+
 @dataclass
 class WandBConfig:
     mode: str = "disabled"
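For orientation, a minimal usage sketch of the new SaverConfig. The field names come from the hunk above, but the import path (arealite.api.cli_args is assumed from the other imports in this commit), the placeholder values, and the should_save helper are illustrative, not commit code.

import time

from arealite.api.cli_args import SaverConfig  # assumed module path

cfg = SaverConfig(
    experiment_name="demo-exp",  # placeholder values
    trial_name="trial-0",
    fileroot="/tmp/arealite",
    freq_steps=100,  # step-based saving only; epoch/second-based stay disabled (None)
)

def should_save(cfg: SaverConfig, epoch: int, step: int, last_save: float) -> bool:
    # Hypothetical helper: save when any configured frequency is due.
    if cfg.freq_epochs is not None and (epoch + 1) % cfg.freq_epochs == 0:
        return True
    if cfg.freq_steps is not None and (step + 1) % cfg.freq_steps == 0:
        return True
    if cfg.freq_secs is not None and time.monotonic() - last_save >= cfg.freq_secs:
        return True
    return False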
@@ -60,6 +60,10 @@ class FinetuneSpec:
         # assuming drop_last
         return self.total_train_epochs * (self.dataset_size // self.train_batch_size)
 
+    @property
+    def steps_per_epoch(self):
+        return self.dataset_size // self.train_batch_size
+
 
 class AllocationType(enum.Enum):
     DECOUPLED_vLLM = 1
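The new steps_per_epoch property is simple integer arithmetic under drop_last semantics. A self-contained sketch using a standalone dataclass that mirrors the FinetuneSpec fields above, with made-up values:

from dataclasses import dataclass

@dataclass
class _Spec:  # stand-in for FinetuneSpec, illustration only
    total_train_epochs: int
    dataset_size: int
    train_batch_size: int

    @property
    def steps_per_epoch(self) -> int:
        # drop_last: the trailing partial batch is not counted
        return self.dataset_size // self.train_batch_size

    @property
    def total_train_steps(self) -> int:
        return self.total_train_epochs * self.steps_per_epoch

spec = _Spec(total_train_epochs=3, dataset_size=1000, train_batch_size=64)
assert spec.steps_per_epoch == 15    # 1000 // 64; the last 40 samples are dropped
assert spec.total_train_steps == 45  # 3 epochs * 15 steps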
@@ -4,7 +4,7 @@ from typing import Dict
 
 import torch.distributed as dist
 import wandb
-from tensorboardX import SummaryWrite
+from tensorboardX import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.cli_args import TrainerConfig
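Since this hunk only fixes the tensorboardX import, here is a brief sketch of how the two loggers imported above are commonly used side by side; the project name, log directory, and metric values are placeholders and not taken from the trainer code.

import wandb
from tensorboardX import SummaryWriter

wandb.init(mode="disabled", project="arealite-demo")  # "disabled" mirrors WandBConfig's default mode
writer = SummaryWriter(log_dir="/tmp/arealite/tb")    # placeholder path

for step, loss in enumerate([0.9, 0.7, 0.5]):
    writer.add_scalar("train/loss", loss, global_step=step)
    wandb.log({"train/loss": loss}, step=step)

writer.close()
wandb.finish()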
@@ -1,6 +1,7 @@
 import asyncio
+import threading
 import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from queue import Empty, Full, Queue
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
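This import block pairs asyncio with threading, ThreadPoolExecutor, and Queue, which points at async generation being driven from background threads. The engine's actual rollout machinery is not shown in this diff, so the sketch below is only a generic illustration of that pattern with hypothetical names: a worker thread owns an event loop and hands results back through a thread-safe queue.

import asyncio
import threading
from queue import Queue

results: Queue = Queue(maxsize=8)

async def fake_generate(prompt: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for an HTTP call to the inference server
    return prompt + " ... generated"

def rollout_thread(prompts):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    for p in prompts:
        results.put(loop.run_until_complete(fake_generate(p)))
    loop.close()

t = threading.Thread(target=rollout_thread, args=(["a", "b"],), daemon=True)
t.start()
t.join()
print(results.get(), results.get())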
@@ -349,7 +350,7 @@ class RemoteSGLangEngine(InferenceEngine):
         finish_reason = meta_info["finish_reason"]
         stop_reason = finish_reason["type"]
 
-        payload["text"] += completions
+        payload["text"] += result["text"]
 
         latency = time.perf_counter() - start_time
 
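The corrected line appends the server's returned text back into the request payload, so a follow-up request continues from the prompt plus everything generated so far. Below is a rough sketch of that pattern against an SGLang-style /generate endpoint; the URL, sampling parameters, and loop bound are assumptions, and only the "text"/"meta_info"/"finish_reason" keys are taken from the hunk.

import requests

payload = {
    "text": "Question: what is 2 + 2?\nAnswer:",
    "sampling_params": {"max_new_tokens": 16, "temperature": 0.0},
}
for _ in range(4):  # arbitrary cap on continuation rounds
    result = requests.post("http://localhost:30000/generate", json=payload).json()
    payload["text"] += result["text"]  # accumulate, as the fixed line does
    if result["meta_info"]["finish_reason"]["type"] == "stop":
        break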
@@ -75,7 +75,6 @@ def sglang_server():
         process.terminate()
 
 
-@pytest.mark.skip("")
 @pytest.mark.asyncio
 async def test_remote_sglang_generate(sglang_server):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -99,7 +98,6 @@ async def test_remote_sglang_generate(sglang_server):
     assert isinstance(resp.completions, str)
 
 
-@pytest.mark.skip("")
 @pytest.mark.parametrize("n_samples", [1, 2, 4])
 def test_remote_sglang_rollout(sglang_server, n_samples):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -8,7 +8,7 @@ from arealite.utils.data import (
     pad_and_stack_tensors_along_first_dim,
     pad_sequences_to_tensors,
     reorder_list,
-    split_packed_tensor_dict_into_mbs,
+    split_packed_tensor_dict_into_mb_list,
     unpack_sequence,
 )
 
@@ -45,7 +45,7 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
     packed_data = pack_tensor_dict(mock_padded_data)
     original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
     assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
-    split_result = split_packed_tensor_dict_into_mbs(packed_data, mb_spec)
+    split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
     reordered_lens = [original_lens[i] for i in split_result.forward_indices]
 
     # assert microbatch split result does not violate requirements
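For readers unfamiliar with the packed layout this test exercises: per-sequence lengths are recovered from cu_seqlens by differencing adjacent offsets, which is exactly what the original_lens line above computes. A self-contained illustration with made-up values:

import torch

# Cumulative token offsets for a packed batch of 3 sequences of lengths 5, 3, and 7.
cu_seqlens = torch.tensor([0, 5, 8, 15])
lens = cu_seqlens[1:] - cu_seqlens[:-1]
assert lens.tolist() == [5, 3, 7]
# A micro-batch split over this packed batch must keep whole sequences together,
# which is what the assertions that follow in the test check.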
@@ -85,4 +85,6 @@ def main_sft():
 
 
 if __name__ == "__main__":
-    main_sft()
+    args = parse_args()
+    trainer = SFTTrainer(args)
+    trainer.train()