PullRequest: 339 [Fix] Fix some minor issues to pass all tests.

Merge branch fw/lite of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/339

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
This commit is contained in:
博惟 2025-07-09 16:51:26 +08:00 committed by 晓雷
parent 15dfbe837c
commit 7be4ab0d18
6 changed files with 9 additions and 6 deletions

View File

@ -60,6 +60,10 @@ class FinetuneSpec:
# assuming drop_last
return self.total_train_epochs * (self.dataset_size // self.train_batch_size)
# Steps in a single epoch: the number of whole batches of `train_batch_size`
# that fit in `dataset_size`. Floor division discards any trailing partial
# batch, so a dataset smaller than one batch yields 0.
@property
def steps_per_epoch(self):
return self.dataset_size // self.train_batch_size
class AllocationType(enum.Enum):
DECOUPLED_vLLM = 1

View File

@ -4,7 +4,7 @@ from typing import Dict
import torch.distributed as dist
import wandb
from tensorboardX import SummaryWrite
from tensorboardX import SummaryWriter
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import TrainerConfig

View File

@ -1,6 +1,7 @@
import asyncio
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@ -349,7 +350,7 @@ class RemoteSGLangEngine(InferenceEngine):
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["text"] += completions
payload["text"] += result["text"]
latency = time.perf_counter() - start_time

View File

@ -75,7 +75,6 @@ def sglang_server():
process.terminate()
@pytest.mark.skip("")
@pytest.mark.asyncio
async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
@ -99,7 +98,6 @@ async def test_remote_sglang_generate(sglang_server):
assert isinstance(resp.completions, str)
@pytest.mark.skip("")
@pytest.mark.parametrize("n_samples", [1, 2, 4])
def test_remote_sglang_rollout(sglang_server, n_samples):
from arealite.engine.sglang_remote import RemoteSGLangEngine

View File

@ -8,7 +8,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_sequences_to_tensors,
reorder_list,
split_packed_tensor_dict_into_mbs,
split_packed_tensor_dict_into_mb_list,
unpack_sequence,
)
@ -45,7 +45,7 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
packed_data = pack_tensor_dict(mock_padded_data)
original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
split_result = split_packed_tensor_dict_into_mbs(packed_data, mb_spec)
split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
reordered_lens = [original_lens[i] for i in split_result.forward_indices]
# assert microbatch split result does not violate requirements