Merge updates from ant repository. (#29)

* fix: `self.tasks_ids` should also be filtered

* PullRequest: 67 Update v0.2.0 Dockerfile

Merge branch fw/v0.2.0-dockerfile of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/67

Signed-off-by: 温差 <xushusheng.xss@antgroup.com>


* fw/v0.2.0-dockerfile

* PullRequest: 66 Update v0.2.0 cover letter

Merge branch fw/v0.2.0-readme of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/66

Signed-off-by: 温差 <xushusheng.xss@antgroup.com>


* .
* .
* .
* .
* .
* update thpt fig
* update readme 20250329-20:16
* update
* update tutorial
* .

* upload 7B zero and 32B sft config

* PullRequest: 72 change the condition of using etcd

Merge branch fw/fix-etcd of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/72

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* change the condition of using etcd

* PullRequest: 60 Change the default SGLang parameters to avoid precision issues.

Merge branch fw/fix-sglang of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/60

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* change vllm config
* .
* .

* PullRequest: 73 Fix a setup issue when using ETCD

Merge branch fw/fix-etcd of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/73

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix etcd
* .
* .

* PullRequest: 75 Fix epoch counter before model function call execution.

Merge branch fw/fix-epoch-counter of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/75

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 76 Update from opensource repository.

Merge branch mzy/update-from-opensource of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/76

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* .
* .
* .
* .
* .
* update thpt fig
* update readme 20250329-20:16
* update
* update tutorial
* fw/v0.2.0-dockerfile
* .
* V0.2.0 prerelease (#8)
* V0.2.0 prerelease (#9)
* Update README.md
* Clean up CI (#11)
* Update README.md citation (#15)
* Xss/readme (#16)
* Merge updates from ant repository. (#18)

* update readme
update readme

update readme

* update readme

* .

* .

* PullRequest: 77 Fix epoch counter for saving models

Merge branch fw/fix-epoch-counter of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/77

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .

* PullRequest: 78 Increase SGLang init timeout from 120 secs to 300 secs

Merge branch fw/incr-sglang-init-timeout of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/78

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 79 Fix default timeout of ETCD name resolve entries.

Merge branch mzy/fix-etcd-default-timeout of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/79?tab=diff

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* fix etcd default timeout
* .

* PullRequest: 81 Fix save-load backend states for Megatron v0.11

Merge branch fw/fix-megatron-v0.11-recover of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/81

Signed-off-by: 温差 <xushusheng.xss@antgroup.com>


* fix megatron v0.11 recover

* PullRequest: 82 Add a push-pull stream for IPC communication between the inference client and the model worker.

Merge branch fw/ipc-stream of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/82?tab=diff

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix megatron v0.11 recover
* .

---------

Signed-off-by: 博惟 <bowei.fw@antgroup.com>
Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com>
Co-authored-by: 博惟 <bowei.fw@antgroup.com>
Co-authored-by: meijun <meijun.mei@antgroup.com>
Co-authored-by: chucai.dzq <chucai.dzq@alibaba-inc.com>
nuzant 2025-04-07 21:49:34 +08:00 committed by GitHub
parent aff05c2544
commit 62e51c3109
GPG Key ID: B5690EEEBB952194
9 changed files with 427 additions and 8 deletions

View File

@ -369,7 +369,7 @@ class FinetuneSpec:
         return (
             self.dataset_size
             - version.global_step * self.train_batch_size % self.dataset_size
-        ) < self.train_batch_size
+        ) <= self.train_batch_size

     def inc_version(self, version: StepInfo) -> StepInfo:
         if self.is_new_epoch(version):
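
Note on the comparison change: the two operators differ only when the samples remaining in the epoch equal the batch size exactly, for example when the batch size divides the dataset size. The sketch below reproduces the arithmetic of the hunk in a standalone function; the function name and the `inclusive` flag are hypothetical and exist only to compare the old and new conditions side by side.

```python
def is_epoch_last_step(
    global_step: int, batch_size: int, dataset_size: int, inclusive: bool
) -> bool:
    """Return True if the step that starts at `global_step` finishes the epoch."""
    # Same expression as in the hunk: `*` and `%` bind tighter than `-`
    # and are evaluated left to right.
    remaining = dataset_size - global_step * batch_size % dataset_size
    if inclusive:
        return remaining <= batch_size  # condition after this commit
    return remaining < batch_size  # condition before this commit


# dataset_size=200, batch_size=25: step 7 consumes samples 175..199,
# so it is the last step of the first epoch (and step 15 of the second).
old = [s for s in range(16) if is_epoch_last_step(s, 25, 200, inclusive=False)]
new = [s for s in range(16) if is_epoch_last_step(s, 25, 200, inclusive=True)]
print(old)  # []
print(new)  # [7, 15]
```

With the strict comparison the epoch boundary is never detected in this configuration; with `<=` it is detected at steps 7 and 15.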

View File

@ -621,7 +621,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
         name,
         value,
         delete_on_exit=True,
-        keepalive_ttl=300,
+        keepalive_ttl=None,
         replace=False,
     ):
         """Add a key-value pair to etcd with optional TTL.

View File

@ -678,7 +678,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
         # Deleting models directly will not release the memory.
         # We must disable hooks at first.
         if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
-            model.module.module.engine.ddp.disable_forward_pre_hook()
+            model.module.engine.ddp.disable_forward_pre_hook()
         else:
             optimizer = model.module.engine.optim
             if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
@ -688,6 +688,19 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
         assert isinstance(model.module, ReaLMegatronEngine)
         optimizer = model.module.engine.optim
         param_state = optimizer.get_parameter_state_fs_bucket_space()
+        assert isinstance(optimizer, DistributedOptimizer)
+        if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
+            # Fix the keyerror: "padding"
+            for gbuf_idx, gbuf_range_maps in enumerate(optimizer.gbuf_ranges):
+                assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
+                for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
+                    for bucket_idx, gbuf_range_map in enumerate(
+                        gbuf_range_map_for_all_buckets
+                    ):
+                        bucket_state = param_state[gbuf_idx][dtype][bucket_idx]
+                        for elem in bucket_state:
+                            elem["padding"] = False
         sd = optimizer.state_dict()
         dp = constants.data_parallel_rank()
         pp = constants.pipe_parallel_rank()
@ -715,7 +728,8 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
         optimizer.load_state_dict(sd)
         param_state = torch.load(
-            pathlib.Path(load_dir) / f"megatron_optim_param_sd_d{dp}p{pp}t{tp}.mckpt"
+            pathlib.Path(load_dir) / f"megatron_optim_param_sd_d{dp}p{pp}t{tp}.mckpt",
+            weights_only=False,
         )
         optimizer.load_parameter_state_from_fs_bucket_space(param_state)

View File

@ -42,6 +42,8 @@ from realhf.base import (
 logger = logging.getLogger("SGLang backend")

+SGLANG_INIT_TIMEOUT = 300
+

 def remove_prefix(text: str, prefix: str) -> str:
     return text[len(prefix) :] if text.startswith(prefix) else text
@ -245,7 +247,7 @@ class SGLangGenerationEngine(PipelinableEngine):
         from sglang.utils import get_exception_traceback

         success = False
-        for _ in range(120):
+        for _ in range(SGLANG_INIT_TIMEOUT):
             await asyncio.sleep(1)
             try:
                 res = requests.get(
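
Note on the timeout constant: the loop sleeps one second per attempt, so the attempt count works as an approximate timeout in seconds, not counting the time each probe takes. A minimal sketch of the same pattern follows. The helper `wait_until_ready`, its parameters, and the `probe` callable are hypothetical and are not part of the SGLang backend; the actual code probes the server over HTTP.

```python
import asyncio
from typing import Callable

INIT_TIMEOUT = 300  # attempts; one per second, mirroring SGLANG_INIT_TIMEOUT


async def wait_until_ready(
    check: Callable[[], bool],
    attempts: int = INIT_TIMEOUT,
    interval_s: float = 1.0,
) -> bool:
    """Poll `check` until it returns True or the attempt budget runs out."""
    for _ in range(attempts):
        await asyncio.sleep(interval_s)
        try:
            if check():
                return True
        except Exception:
            # A failed probe (e.g. connection refused) means "not ready yet".
            continue
    return False


if __name__ == "__main__":
    state = {"calls": 0}

    def probe() -> bool:
        # Hypothetical health check that succeeds on the third call.
        state["calls"] += 1
        return state["calls"] >= 3

    print(asyncio.run(wait_until_ready(probe, attempts=5, interval_s=0.01)))
```

Keeping the budget in a module-level constant lets the limit be raised in one place, which is what the hunk does by replacing the literal 120.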

View File

@ -374,9 +374,11 @@ class MasterWorker(worker_base.Worker):
         s = f"The next step is epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
         s += f"step {epoch_step}/{self._steps_per_epoch} "
         s += f"(global step {global_step}). "
-        s += f"Should save a checkpoint for recover? {self.__rpc_ctrl.should_ckpt}. "
-        s += f"Should save a persistent checkpoint for evaluation? {self.__rpc_ctrl.should_save}. "
+        s += f"Should checkpoint? {self.__rpc_ctrl.should_ckpt}. "
+        s += f"Should save? {self.__rpc_ctrl.should_save}. "
         s += f"Should run evaluation? {self.__rpc_ctrl.should_eval}. "
+        s += f"Is the first step in epoch? {is_new_epoch}. "
+        s += f"Is the last step in epoch? {is_epoch_last_step}. "
         self.logger.info(s)

         # Traverse over the dataflow graph for once.

View File

@ -0,0 +1,128 @@
import logging
from queue import Empty as QueueEmpty
from typing import Any, Dict, List, Optional, Union

import orjson
import zmq
from zmq.utils.strtypes import asbytes

from realhf.base import logging

logger = logging.getLogger("ZMQ Push-Pull Stream")

# Type alias for JSON-compatible objects
JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]


class ZMQJsonPusher:
    """
    JSON pusher using ZeroMQ.

    Args:
        host: Host address (default: 'localhost')
        port: Port number (default: 5555)
        hwm: High-water mark for outgoing messages (default: 1000)
    """

    def __init__(self, host: str = "localhost", port: int = 5555, hwm: int = 1000):
        self.host = host
        self.port = port

        self.ctx = zmq.Context.instance()
        self.socket = self.ctx.socket(zmq.PUSH)
        self.socket.setsockopt(zmq.SNDHWM, hwm)
        self.socket.connect(f"tcp://{self.host}:{self.port}")

    def push(self, data: JSONType) -> None:
        """
        Push JSON-compatible data efficiently.

        Args:
            data: JSON-serializable Python object

        Raises:
            TypeError: If data is not JSON-serializable
            zmq.ZMQError: If ZeroMQ operation fails
        """
        try:
            # Directly encode to bytes without intermediate string
            json_bytes = asbytes(orjson.dumps(data))
            self.socket.send(json_bytes, flags=zmq.NOBLOCK, copy=False)
        except (TypeError, ValueError) as e:
            raise TypeError(f"Data not JSON-serializable: {e}")
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                logger.warning("Push operation would block (queue full)")
            raise

    def close(self) -> None:
        """Clean up resources."""
        self.socket.close(linger=0)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class ZMQJsonPuller:
    """
    JSON puller using ZeroMQ with per-call timeout support in pull() method.

    Args:
        host: Host address (default: 'localhost')
        port: Port number (default: 5555)
        default_timeout_ms: Default receive timeout in milliseconds (default: 1000)
        hwm: High-water mark for incoming messages (default: 1000)
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5555,
        default_timeout_ms: int = 1000,
        hwm: int = 1000,
    ):
        self.host = host
        self.port = port
        self.default_timeout_ms = default_timeout_ms

        self.ctx = zmq.Context.instance()
        self.socket = self.ctx.socket(zmq.PULL)
        self.socket.setsockopt(zmq.RCVHWM, hwm)
        self.socket.setsockopt(zmq.RCVTIMEO, self.default_timeout_ms)
        self.socket.bind(f"tcp://{self.host}:{self.port}")

        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    def pull(self, timeout_ms: Optional[int] = None) -> JSONType:
        """
        Pull and decode JSON data with configurable timeout.

        Args:
            timeout_ms: Optional timeout in milliseconds. If None, uses default_timeout_ms.

        Returns:
            Deserialized JSON-compatible Python object

        Raises:
            queue.Empty: If no message available within timeout
        """
        current_timeout = self.default_timeout_ms if timeout_ms is None else timeout_ms

        events = dict(self.poller.poll(current_timeout))
        if self.socket in events:
            msg = self.socket.recv(flags=zmq.NOBLOCK, copy=False)
            return orjson.loads(msg.bytes.decode("utf-8"))
        raise QueueEmpty(f"No data available after {current_timeout}ms timeout")

    def close(self) -> None:
        """Clean up resources."""
        self.socket.close(linger=0)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
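
Note on the stream contract: `push()` is non-blocking and raises `TypeError` for data that cannot be serialized, while `pull()` waits up to a per-call timeout in milliseconds and raises `queue.Empty` when nothing arrives. The sketch below imitates that contract with the standard library only, swapping in `queue.Queue` and `json` for ZeroMQ and orjson so that it runs without extra dependencies. The class `InProcessJsonStream` is hypothetical and is not part of this commit.

```python
import json
import queue
from typing import Any, Optional


class InProcessJsonStream:
    """Hypothetical in-process stand-in for the pusher/puller pair."""

    def __init__(self, default_timeout_ms: int = 1000, hwm: int = 1000):
        self.default_timeout_ms = default_timeout_ms
        # maxsize plays the role of the high-water mark.
        self._q: "queue.Queue[bytes]" = queue.Queue(maxsize=hwm)

    def push(self, data: Any) -> None:
        try:
            payload = json.dumps(data).encode("utf-8")
        except (TypeError, ValueError) as e:
            raise TypeError(f"Data not JSON-serializable: {e}") from e
        # Non-blocking: raises queue.Full once the high-water mark is reached.
        self._q.put_nowait(payload)

    def pull(self, timeout_ms: Optional[int] = None) -> Any:
        timeout = self.default_timeout_ms if timeout_ms is None else timeout_ms
        try:
            payload = self._q.get(timeout=timeout / 1000)
        except queue.Empty:
            raise queue.Empty(f"No data available after {timeout}ms timeout") from None
        return json.loads(payload)


if __name__ == "__main__":
    stream = InProcessJsonStream(default_timeout_ms=100)
    stream.push({"key": "value", "num": 42})
    print(stream.pull())
    try:
        stream.pull(timeout_ms=10)
    except queue.Empty as exc:
        print(f"empty: {exc}")
```

A caller that polls in a loop can treat `queue.Empty` as "no message yet" and retry, which is how the rapid-fire test in this commit drains the puller.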

View File

@ -56,4 +56,5 @@ aiohttp
 httpx>=0.28.1
 etcd3
 protobuf<3.21
-rich
+rich
+orjson>=3.10.16

View File

@ -0,0 +1,49 @@
import pytest

from realhf.api.core.model_api import FinetuneSpec, StepInfo


@pytest.mark.parametrize("total_train_epochs", [10])
@pytest.mark.parametrize("train_batch_size", [64, 25, 11, 38])
@pytest.mark.parametrize("dataset_size", [200, 168, 77])
def test_epoch_counter(
    total_train_epochs: int, train_batch_size: int, dataset_size: int
):
    ft_spec = FinetuneSpec(
        total_train_epochs=total_train_epochs,
        train_batch_size=train_batch_size,
        dataset_size=dataset_size,
    )
    version = StepInfo()
    _epoch = 0
    _epoch_step = 0
    _step = 0

    is_new_epoch_records = []
    is_last_step_records = []
    gt = []
    while _epoch < total_train_epochs:
        is_last_step_records.append(ft_spec.is_epoch_last_step(version))
        is_new_epoch = ft_spec.is_new_epoch(version)
        is_new_epoch_records.append(is_new_epoch)
        if is_new_epoch:
            version.epoch += 1
            version.epoch_step = 0
        assert version.epoch == _epoch
        assert version.epoch_step == _epoch_step
        assert version.global_step == _step
        version.epoch_step += 1
        version.global_step += 1

        _step += 1
        _epoch_step += 1
        if _step * train_batch_size >= dataset_size * (_epoch + 1):
            _epoch += 1
            _epoch_step = 0
            gt.append(True)
        else:
            gt.append(False)
    assert gt == is_last_step_records
    assert [False] + is_last_step_records[:-1] == is_new_epoch_records

View File

@ -0,0 +1,223 @@
import threading
import time
from queue import Empty as QueueEmpty
from typing import Any

import pytest

from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher

# Constants for testing
TEST_PORT = 5557  # Different from default to avoid conflicts
TEST_HOST = "127.0.0.1"
TIMEOUT = 1000  # ms


# Fixtures for clean setup/teardown
@pytest.fixture
def puller():
    """Fixture providing a puller instance"""
    with ZMQJsonPuller(host=TEST_HOST, port=TEST_PORT, default_timeout_ms=TIMEOUT) as p:
        yield p


@pytest.fixture
def pusher():
    """Fixture providing a pusher instance"""
    with ZMQJsonPusher(host=TEST_HOST, port=TEST_PORT) as p:
        yield p


# Helper functions
def run_pusher_in_thread(data: Any, count: int = 1, delay: float = 0.1):
    """Helper to run pusher in a separate thread"""

    def pusher_thread():
        with ZMQJsonPusher(host=TEST_HOST, port=TEST_PORT) as p:
            for _ in range(count):
                p.push(data)
                time.sleep(delay)

    thread = threading.Thread(target=pusher_thread)
    thread.start()
    return thread


## Test Cases
def test_basic_push_pull(pusher, puller):
    """Test basic push-pull functionality with simple data"""
    test_data = {"key": "value", "num": 42, "flag": True}
    pusher.push(test_data)
    time.sleep(1)
    received = puller.pull()
    assert received == test_data


def test_empty_queue_raises(puller):
    """Test that pulling from empty queue raises QueueEmpty"""
    with pytest.raises(QueueEmpty):
        puller.pull()


def test_push_non_json_data(pusher):
    """Test that pushing non-JSON data raises TypeError"""

    class NonJSON:
        pass

    with pytest.raises(TypeError):
        pusher.push(NonJSON())
    with pytest.raises(TypeError):
        pusher.push(b"binary data")


def test_push_large_data(pusher, puller):
    """Test pushing and pulling large JSON data"""
    large_data = {
        "array": list(range(1000)),
        "nested": {"deep": [{"deeper": True} for _ in range(100)]},
    }
    pusher.push(large_data)
    received = puller.pull()
    assert received == large_data


def test_pull_with_custom_timeout(puller):
    """Test pull() with custom timeout parameter"""
    # Test with very short timeout (should raise almost immediately)
    start_time = time.time()
    with pytest.raises(QueueEmpty, match="No data available after 10ms timeout"):
        puller.pull(timeout_ms=10)
    elapsed = time.time() - start_time
    assert elapsed < 0.05  # 10 ms timeout plus scheduling slack

    # Test with longer timeout while data is coming
    test_data = {"test": "timeout"}
    thread = run_pusher_in_thread(test_data, delay=0.2)
    # Should get data within 500ms timeout
    start_time = time.time()
    received = puller.pull(timeout_ms=500)
    elapsed = time.time() - start_time
    assert received == test_data
    assert elapsed < 0.5  # Should complete before timeout
    thread.join()


def test_pull_timeout_none(puller):
    """Test that pull(timeout_ms=None) uses default timeout"""
    start_time = time.time()
    with pytest.raises(QueueEmpty):
        puller.pull(timeout_ms=None)  # Should use default 1000ms
    elapsed = time.time() - start_time
    assert abs(elapsed - 1.0) < 0.1  # Approximately default timeout


def test_pull_timeout_zero(puller):
    """Test that pull(timeout_ms=0) returns immediately"""
    start_time = time.time()
    with pytest.raises(QueueEmpty, match="No data available after 0ms timeout"):
        puller.pull(timeout_ms=0)
    elapsed = time.time() - start_time
    assert elapsed < 0.01  # Should be nearly instantaneous


def test_mixed_timeout_usage(pusher, puller):
    """Test mixing different timeout values in successive calls"""
    # First with default timeout (should wait)
    start_time = time.time()
    with pytest.raises(QueueEmpty):
        puller.pull()
    assert abs((time.time() - start_time) - 1.0) < 0.1

    # Then with short timeout
    start_time = time.time()
    with pytest.raises(QueueEmpty):
        puller.pull(timeout_ms=100)
    assert abs((time.time() - start_time) - 0.1) < 0.05

    # Then push some data
    pusher.push({"value": 42})
    # Should get it even with very long timeout
    start_time = time.time()
    assert puller.pull(timeout_ms=5000) == {"value": 42}
    assert (time.time() - start_time) < 0.1  # Should return immediately


def test_timeout_restoration_after_error(puller):
    """Test that default timeout is restored after error"""
    original_timeout = puller.default_timeout_ms
    # First verify default behavior
    with pytest.raises(QueueEmpty):
        puller.pull()
    # Change timeout temporarily
    with pytest.raises(QueueEmpty):
        puller.pull(timeout_ms=100)
    # Verify default timeout is restored
    start_time = time.time()
    with pytest.raises(QueueEmpty):
        puller.pull()
    elapsed = time.time() - start_time
    assert abs(elapsed - (original_timeout / 1000)) < 0.1


def test_concurrent_access(puller):
    """Test that puller can handle concurrent pushes"""
    test_data = {"message": "hello"}
    thread = run_pusher_in_thread(test_data)
    received = puller.pull()
    thread.join()
    assert received == test_data


def test_multiple_messages(pusher, puller):
    """Test sending and receiving multiple messages"""
    messages = [
        {"id": 1, "content": "first"},
        {"id": 2, "content": "second"},
        {"id": 3, "content": "third"},
    ]
    for msg in messages:
        pusher.push(msg)
    for expected in messages:
        assert puller.pull() == expected


def test_rapid_fire_messages(puller):
    """Test handling of many rapid messages"""
    test_data = {"count": 0}
    thread = run_pusher_in_thread(test_data, count=100, delay=0.01)
    received_count = 0
    while True:
        try:
            puller.pull()
            received_count += 1
            if received_count >= 100:
                break
        except QueueEmpty:
            if not thread.is_alive():
                break
    thread.join()
    assert received_count == 100