mirror of https://github.com/inclusionAI/AReaL
PullRequest: 339 [Fix] Fix some minor issues to pass all tests.
Merge branch fw/lite of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/339
Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* support fsdp engine and sglang remote engine
* minor fix
* .
This commit is contained in:
parent 15dfbe837c
commit 7be4ab0d18
@@ -60,6 +60,10 @@ class FinetuneSpec:
         # assuming drop_last
         return self.total_train_epochs * (self.dataset_size // self.train_batch_size)
 
+    @property
+    def steps_per_epoch(self):
+        return self.dataset_size // self.train_batch_size
+
 
 class AllocationType(enum.Enum):
     DECOUPLED_vLLM = 1
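For readers skimming the diff, here is a minimal, self-contained sketch of the step-count arithmetic the new steps_per_epoch property factors out. Only total_train_epochs, dataset_size, train_batch_size, and the two properties come from the hunk; the dataclass wrapper, class name, and example numbers are illustrative assumptions.

from dataclasses import dataclass

@dataclass
class FinetuneSpecSketch:  # illustrative stand-in, not the real FinetuneSpec
    total_train_epochs: int
    dataset_size: int
    train_batch_size: int

    @property
    def steps_per_epoch(self) -> int:
        # drop_last semantics: a trailing partial batch is discarded
        return self.dataset_size // self.train_batch_size

    @property
    def total_train_steps(self) -> int:
        return self.total_train_epochs * self.steps_per_epoch

spec = FinetuneSpecSketch(total_train_epochs=3, dataset_size=1000, train_batch_size=64)
assert spec.steps_per_epoch == 15      # 1000 // 64
assert spec.total_train_steps == 45    # 3 * 15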
@@ -4,7 +4,7 @@ from typing import Dict
 
 import torch.distributed as dist
 import wandb
-from tensorboardX import SummaryWrite
+from tensorboardX import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.cli_args import TrainerConfig
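The corrected name matters because tensorboardX exports SummaryWriter, not SummaryWrite. A minimal usage sketch of the API the trainer now imports correctly; the log directory and tag below are made up for illustration.

from tensorboardX import SummaryWriter

writer = SummaryWriter("logs/example_run")  # hypothetical log directory
for step in range(3):
    writer.add_scalar("train/loss", 1.0 / (step + 1), global_step=step)
writer.close()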
@@ -1,6 +1,7 @@
 import asyncio
+import threading
 import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from queue import Empty, Full, Queue
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@@ -349,7 +350,7 @@ class RemoteSGLangEngine(InferenceEngine):
         finish_reason = meta_info["finish_reason"]
         stop_reason = finish_reason["type"]
 
-        payload["text"] += completions
+        payload["text"] += result["text"]
 
         latency = time.perf_counter() - start_time
 
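The fix appends the text field of the server response result instead of a stale completions variable. A rough sketch of the accumulation pattern the hunk implies, where each round's generated text is appended to the running payload so a follow-up request continues from it; fake_generate and the stop handling below are invented for illustration, not the real RemoteSGLangEngine code.

import time

def fake_generate(payload: dict) -> dict:
    # Stand-in for a remote /generate call; returns an SGLang-style response.
    return {"text": " world", "meta_info": {"finish_reason": {"type": "stop"}}}

payload = {"text": "hello"}
start_time = time.perf_counter()

result = fake_generate(payload)
meta_info = result["meta_info"]
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["text"] += result["text"]   # append the newly generated text

latency = time.perf_counter() - start_time
assert payload["text"] == "hello world" and stop_reason == "stop"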
@@ -75,7 +75,6 @@ def sglang_server():
     process.terminate()
 
 
-@pytest.mark.skip("")
 @pytest.mark.asyncio
 async def test_remote_sglang_generate(sglang_server):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
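Removing the blanket @pytest.mark.skip("") re-enables this test against the sglang_server fixture. For orientation only, a generic sketch of what a server-launching fixture of this shape usually looks like; the launch command and everything beyond process.terminate() are assumptions, not AReaL's actual fixture.

import subprocess

import pytest

@pytest.fixture
def sglang_server_sketch():
    # Placeholder launch command; a real fixture would start an SGLang server
    # and wait until it is ready to serve requests.
    process = subprocess.Popen(["python", "-m", "http.server", "18000"])
    try:
        yield process
    finally:
        process.terminate()
        process.wait()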
@@ -99,7 +98,6 @@ async def test_remote_sglang_generate(sglang_server):
     assert isinstance(resp.completions, str)
 
 
-@pytest.mark.skip("")
 @pytest.mark.parametrize("n_samples", [1, 2, 4])
 def test_remote_sglang_rollout(sglang_server, n_samples):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -8,7 +8,7 @@ from arealite.utils.data import (
     pad_and_stack_tensors_along_first_dim,
     pad_sequences_to_tensors,
     reorder_list,
-    split_packed_tensor_dict_into_mbs,
+    split_packed_tensor_dict_into_mb_list,
     unpack_sequence,
 )
 
@@ -45,7 +45,7 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
     packed_data = pack_tensor_dict(mock_padded_data)
     original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
     assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
-    split_result = split_packed_tensor_dict_into_mbs(packed_data, mb_spec)
+    split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
     reordered_lens = [original_lens[i] for i in split_result.forward_indices]
 
     # assert microbatch split result does not violate requirements
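The import rename above and this test track the same API change: the splitter now returns a microbatch list whose forward_indices record how the packed sequences were reordered. A toy illustration of the requirement the trailing comment refers to; only cu_seqlens, forward_indices, and max_tokens_per_mb come from the hunks, and the split itself is invented.

import torch

cu_seqlens = torch.tensor([0, 3, 7, 12, 14])        # four packed sequences
original_lens = cu_seqlens[1:] - cu_seqlens[:-1]     # tensor([3, 4, 5, 2])

forward_indices = [2, 0, 3, 1]                       # hypothetical visiting order
group_sizes = [2, 2]                                 # hypothetical microbatch sizes
max_tokens_per_mb = 8

reordered_lens = [int(original_lens[i]) for i in forward_indices]  # [5, 3, 2, 4]
mb_token_counts, offset = [], 0
for size in group_sizes:
    mb_token_counts.append(sum(reordered_lens[offset:offset + size]))
    offset += size
# Each microbatch stays within the token budget: [8, 6], both <= 8.
assert all(tokens <= max_tokens_per_mb for tokens in mb_token_counts)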