mirror of https://github.com/inclusionAI/AReaL
Merge updates from ant repository. (#29)
* fix: `self.tasks_ids` should also be filtered * PullRequest: 67 Update v0.2.0 Dockerfile Merge branch fw/v0.2.0-dockerfile of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/67 Signed-off-by: 温差 <xushusheng.xss@antgroup.com> * fw/v0.2.0-dockerfile * PullRequest: 66 Update v0.2.0 cover letter Merge branch fw/v0.2.0-readme of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/66 Signed-off-by: 温差 <xushusheng.xss@antgroup.com> * . * . * . * . * . * update thpt fig * update readme 20250329-20:16 * update * update tutorial * . * upload 7B zero and 32B sft config * PullRequest: 72 change the condition of using etcd Merge branch fw/fix-etcd of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/72 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * change the condition of using etcd * PullRequest: 60 Change the default SGLang parameters to avoid precision issues. Merge branch fw/fix-sglang of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/60 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * change vllm config * . * . * PullRequest: 73 Fix a setup issue when using ETCD Merge branch fw/fix-etcd of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/73 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix etcd * . * . * PullRequest: 75 Fix epoch counter before model function call execution. Merge branch fw/fix-epoch-counter of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/75 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 76 Update from opensource repository. 
Merge branch mzy/update-from-opensource of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/76 Signed-off-by: 博惟 <bowei.fw@antgroup.com> * . * . * . * . * . * update thpt fig * update readme 20250329-20:16 * update * update tutorial * fw/v0.2.0-dockerfile * . * V0.2.0 prerelease (#8) * V0.2.0 prerelease (#9) * Update README.md * Clean up CI (#11) * Update README.md citation (#15) * Xss/readme (#16) * Merge updates from ant repository. (#18) * update readme update readme update readme * update readme * . * . * PullRequest: 77 Fix epoch counter for saving models Merge branch fw/fix-epoch-counter of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/77 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * PullRequest: 78 Increase SGLang init timeout from 120 secs to 300 secs Merge branch fw/incr-sglang-init-timeout of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/78 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 79 Fix default timeout of ETCD name resolve entries. Merge branch mzy/fix-etcd-default-timeout of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/79?tab=diff Signed-off-by: 博惟 <bowei.fw@antgroup.com> * fix etcd default timeout * . * PullRequest: 81 Fix save-load backend states for Megatron v0.11 Merge branch fw/fix-megatron-v0.11-recover of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/81 Signed-off-by: 温差 <xushusheng.xss@antgroup.com> * fix megatron v0.11 recover * PullRequest: 82 Add a push-pull stream for IPC communication between the inference client and the model worker. 
Merge branch fw/ipc-stream of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/82?tab=diff Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix megatron v0.11 recover * . --------- Signed-off-by: 博惟 <bowei.fw@antgroup.com> Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com> Co-authored-by: 博惟 <bowei.fw@antgroup.com> Co-authored-by: meijun <meijun.mei@antgroup.com> Co-authored-by: chucai.dzq <chucai.dzq@alibaba-inc.com>
This commit is contained in:
parent
aff05c2544
commit
62e51c3109
|
@ -369,7 +369,7 @@ class FinetuneSpec:
|
|||
return (
|
||||
self.dataset_size
|
||||
- version.global_step * self.train_batch_size % self.dataset_size
|
||||
) < self.train_batch_size
|
||||
) <= self.train_batch_size
|
||||
|
||||
def inc_version(self, version: StepInfo) -> StepInfo:
|
||||
if self.is_new_epoch(version):
|
||||
|
|
|
@ -621,7 +621,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
|
|||
name,
|
||||
value,
|
||||
delete_on_exit=True,
|
||||
keepalive_ttl=300,
|
||||
keepalive_ttl=None,
|
||||
replace=False,
|
||||
):
|
||||
"""Add a key-value pair to etcd with optional TTL.
|
||||
|
|
|
@ -678,7 +678,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
# Deleting models directly will not release the memory.
|
||||
# We must disable hooks at first.
|
||||
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
|
||||
model.module.module.engine.ddp.disable_forward_pre_hook()
|
||||
model.module.engine.ddp.disable_forward_pre_hook()
|
||||
else:
|
||||
optimizer = model.module.engine.optim
|
||||
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
|
||||
|
@ -688,6 +688,19 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
assert isinstance(model.module, ReaLMegatronEngine)
|
||||
optimizer = model.module.engine.optim
|
||||
param_state = optimizer.get_parameter_state_fs_bucket_space()
|
||||
assert isinstance(optimizer, DistributedOptimizer)
|
||||
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
|
||||
# Fix the keyerror: "padding"
|
||||
for gbuf_idx, gbuf_range_maps in enumerate(optimizer.gbuf_ranges):
|
||||
assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
|
||||
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
|
||||
for bucket_idx, gbuf_range_map in enumerate(
|
||||
gbuf_range_map_for_all_buckets
|
||||
):
|
||||
bucket_state = param_state[gbuf_idx][dtype][bucket_idx]
|
||||
for elem in bucket_state:
|
||||
elem["padding"] = False
|
||||
|
||||
sd = optimizer.state_dict()
|
||||
dp = constants.data_parallel_rank()
|
||||
pp = constants.pipe_parallel_rank()
|
||||
|
@ -715,7 +728,8 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
optimizer.load_state_dict(sd)
|
||||
|
||||
param_state = torch.load(
|
||||
pathlib.Path(load_dir) / f"megatron_optim_param_sd_d{dp}p{pp}t{tp}.mckpt"
|
||||
pathlib.Path(load_dir) / f"megatron_optim_param_sd_d{dp}p{pp}t{tp}.mckpt",
|
||||
weights_only=False,
|
||||
)
|
||||
optimizer.load_parameter_state_from_fs_bucket_space(param_state)
|
||||
|
||||
|
|
|
@ -42,6 +42,8 @@ from realhf.base import (
|
|||
|
||||
logger = logging.getLogger("SGLang backend")
|
||||
|
||||
SGLANG_INIT_TIMEOUT = 300
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` stripped.

    Args:
        text: The string to process.
        prefix: The prefix to strip when ``text`` starts with it.

    Returns:
        ``text`` without the prefix, or ``text`` unchanged when it does not
        start with ``prefix``.
    """
    if not text.startswith(prefix):
        return text
    cut = len(prefix)
    return text[cut:]
|
||||
|
@ -245,7 +247,7 @@ class SGLangGenerationEngine(PipelinableEngine):
|
|||
from sglang.utils import get_exception_traceback
|
||||
|
||||
success = False
|
||||
for _ in range(120):
|
||||
for _ in range(SGLANG_INIT_TIMEOUT):
|
||||
await asyncio.sleep(1)
|
||||
try:
|
||||
res = requests.get(
|
||||
|
|
|
@ -374,9 +374,11 @@ class MasterWorker(worker_base.Worker):
|
|||
s = f"The next step is epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
|
||||
s += f"step {epoch_step}/{self._steps_per_epoch} "
|
||||
s += f"(global step {global_step}). "
|
||||
s += f"Should save a checkpoint for recover? {self.__rpc_ctrl.should_ckpt}. "
|
||||
s += f"Should save a persistent checkpoint for evaluation? {self.__rpc_ctrl.should_save}. "
|
||||
s += f"Should checkpoint? {self.__rpc_ctrl.should_ckpt}. "
|
||||
s += f"Should save? {self.__rpc_ctrl.should_save}. "
|
||||
s += f"Should run evaluation? {self.__rpc_ctrl.should_eval}. "
|
||||
s += f"Is the first step in epoch? {is_new_epoch}. "
|
||||
s += f"Is the last step in epoch? {is_epoch_last_step}. "
|
||||
self.logger.info(s)
|
||||
|
||||
# Traverse over the dataflow graph for once.
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
import logging
|
||||
from queue import Empty as QueueEmpty
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import orjson
|
||||
import zmq
|
||||
from zmq.utils.strtypes import asbytes
|
||||
|
||||
from realhf.base import logging
|
||||
|
||||
logger = logging.getLogger("ZMQ Push-Pull Stream")
|
||||
|
||||
# Type alias for JSON-compatible objects
|
||||
JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]
|
||||
|
||||
|
||||
class ZMQJsonPusher:
    """
    JSON pusher using ZeroMQ.

    Creates a PUSH socket on the shared ZeroMQ context, connects it to
    ``tcp://host:port`` and sends every pushed object as a single
    orjson-encoded message. Usable as a context manager; the socket is
    closed on exit.

    Args:
        host: Host address (default: 'localhost')
        port: Port number (default: 5555)
        hwm: High-water mark for outgoing messages (default: 1000)
    """

    def __init__(self, host: str = "localhost", port: int = 5555, hwm: int = 1000):
        self.host = host
        self.port = port

        self.ctx = zmq.Context.instance()
        self.socket = self.ctx.socket(zmq.PUSH)
        self.socket.setsockopt(zmq.SNDHWM, hwm)
        self.socket.connect(f"tcp://{self.host}:{self.port}")

    def push(self, data: JSONType) -> None:
        """
        Push JSON-compatible data efficiently.

        The send is non-blocking: when the outgoing queue is full a warning
        is logged and the ``zmq.ZMQError`` (errno ``EAGAIN``) is re-raised.

        Args:
            data: JSON-serializable Python object

        Raises:
            TypeError: If data is not JSON-serializable
            zmq.ZMQError: If ZeroMQ operation fails
        """
        # orjson.dumps already returns bytes, so no extra conversion is needed.
        # Encoding is kept in its own try so that a send-side failure is never
        # misreported as a serialization problem.
        try:
            json_bytes = orjson.dumps(data)
        except (TypeError, ValueError) as e:
            raise TypeError(f"Data not JSON-serializable: {e}") from e

        try:
            self.socket.send(json_bytes, flags=zmq.NOBLOCK, copy=False)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                logger.warning("Push operation would block (queue full)")
            raise

    def close(self) -> None:
        """Clean up resources."""
        # linger=0 drops any unsent messages instead of blocking on close.
        self.socket.close(linger=0)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
||||
|
||||
|
||||
class ZMQJsonPuller:
    """
    JSON puller using ZeroMQ with per-call timeout support in pull() method.

    Binds a PULL socket to ``tcp://host:port`` and registers it with a poller
    so that each ``pull()`` call can wait with its own timeout. Usable as a
    context manager; the socket is closed on exit.

    Args:
        host: Host address (default: 'localhost')
        port: Port number (default: 5555)
        default_timeout_ms: Default receive timeout in milliseconds (default: 1000)
        hwm: High-water mark for incoming messages (default: 1000)
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5555,
        default_timeout_ms: int = 1000,
        hwm: int = 1000,
    ):
        self.host = host
        self.port = port
        self.default_timeout_ms = default_timeout_ms

        self.ctx = zmq.Context.instance()
        self.socket = self.ctx.socket(zmq.PULL)
        self.socket.setsockopt(zmq.RCVHWM, hwm)
        self.socket.setsockopt(zmq.RCVTIMEO, self.default_timeout_ms)
        self.socket.bind(f"tcp://{self.host}:{self.port}")

        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    def pull(self, timeout_ms: Optional[int] = None) -> JSONType:
        """
        Pull and decode JSON data with configurable timeout.

        Args:
            timeout_ms: Optional timeout in milliseconds. If None, uses
                default_timeout_ms.

        Returns:
            Deserialized JSON-compatible Python object

        Raises:
            queue.Empty: If no message available within timeout
        """
        current_timeout = self.default_timeout_ms if timeout_ms is None else timeout_ms
        events = dict(self.poller.poll(current_timeout))
        if self.socket in events:
            msg = self.socket.recv(flags=zmq.NOBLOCK, copy=False)
            # orjson parses UTF-8 bytes directly; no intermediate str needed.
            return orjson.loads(msg.bytes)
        raise QueueEmpty(f"No data available after {current_timeout}ms timeout")

    def close(self) -> None:
        """Clean up resources. Calling it more than once is a no-op."""
        if self.socket.closed:
            return
        # Drop the poller's reference before closing the socket.
        self.poller.unregister(self.socket)
        self.socket.close(linger=0)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
|
@ -56,4 +56,5 @@ aiohttp
|
|||
httpx>=0.28.1
|
||||
etcd3
|
||||
protobuf<3.21
|
||||
rich
|
||||
rich
|
||||
orjson>=3.10.16
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
|
||||
from realhf.api.core.model_api import FinetuneSpec, StepInfo
|
||||
|
||||
|
||||
@pytest.mark.parametrize("total_train_epochs", [10])
@pytest.mark.parametrize("train_batch_size", [64, 25, 11, 38])
@pytest.mark.parametrize("dataset_size", [200, 168, 77])
def test_epoch_counter(
    total_train_epochs: int, train_batch_size: int, dataset_size: int
):
    """Check FinetuneSpec epoch-boundary predicates against a reference counter.

    A hand-maintained (epoch, epoch_step, global_step) triple is advanced in
    lockstep with a ``StepInfo`` that is driven by ``is_new_epoch``; both must
    agree at every step, and the recorded predicate values must match the
    reference "last step of epoch" sequence.
    """
    spec = FinetuneSpec(
        total_train_epochs=total_train_epochs,
        train_batch_size=train_batch_size,
        dataset_size=dataset_size,
    )
    info = StepInfo()

    # Reference counters maintained independently of FinetuneSpec.
    ref_epoch, ref_epoch_step, ref_global_step = 0, 0, 0

    new_epoch_flags = []
    last_step_flags = []
    expected_last_step = []

    while ref_epoch < total_train_epochs:
        last_step_flags.append(spec.is_epoch_last_step(info))
        starts_new_epoch = spec.is_new_epoch(info)
        new_epoch_flags.append(starts_new_epoch)

        if starts_new_epoch:
            info.epoch += 1
            info.epoch_step = 0

        assert info.epoch == ref_epoch
        assert info.epoch_step == ref_epoch_step
        assert info.global_step == ref_global_step

        info.epoch_step += 1
        info.global_step += 1

        ref_global_step += 1
        ref_epoch_step += 1
        # The epoch ends once the consumed samples cover (epoch + 1) datasets.
        epoch_done = ref_global_step * train_batch_size >= dataset_size * (
            ref_epoch + 1
        )
        if epoch_done:
            ref_epoch += 1
            ref_epoch_step = 0
        expected_last_step.append(epoch_done)

    assert expected_last_step == last_step_flags
    # A step opens a new epoch exactly when the previous step closed one.
    assert [False] + last_step_flags[:-1] == new_epoch_flags
|
|
@ -0,0 +1,223 @@
|
|||
import threading
|
||||
import time
|
||||
from queue import Empty as QueueEmpty
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from realhf.system.push_pull_stream import ( # Replace with your actual module name
|
||||
ZMQJsonPuller,
|
||||
ZMQJsonPusher,
|
||||
)
|
||||
|
||||
# Constants for testing
|
||||
TEST_PORT = 5557 # Different from default to avoid conflicts
|
||||
TEST_HOST = "127.0.0.1"
|
||||
TIMEOUT = 1000 # ms
|
||||
|
||||
|
||||
# Fixtures for clean setup/teardown
|
||||
@pytest.fixture
def puller():
    """Fixture yielding a bound puller that is closed on teardown."""
    receiver = ZMQJsonPuller(
        host=TEST_HOST, port=TEST_PORT, default_timeout_ms=TIMEOUT
    )
    with receiver as p:
        yield p
|
||||
|
||||
|
||||
@pytest.fixture
def pusher():
    """Fixture yielding a connected pusher that is closed on teardown."""
    sender = ZMQJsonPusher(host=TEST_HOST, port=TEST_PORT)
    with sender as p:
        yield p
|
||||
|
||||
|
||||
# Helper functions
|
||||
def run_pusher_in_thread(data: Any, count: int = 1, delay: float = 0.1):
    """Start a background thread that pushes ``data`` ``count`` times.

    Each push is followed by a sleep of ``delay`` seconds. The started thread
    is returned so the caller can join it.
    """

    def _send_all():
        with ZMQJsonPusher(host=TEST_HOST, port=TEST_PORT) as sender:
            remaining = count
            while remaining > 0:
                sender.push(data)
                time.sleep(delay)
                remaining -= 1

    worker = threading.Thread(target=_send_all)
    worker.start()
    return worker
|
||||
|
||||
|
||||
## Test Cases
|
||||
|
||||
|
||||
def test_basic_push_pull(pusher, puller):
    """A simple dict pushed by the pusher arrives unchanged at the puller."""
    payload = {"key": "value", "num": 42, "flag": True}

    pusher.push(payload)
    time.sleep(1)

    assert puller.pull() == payload
|
||||
|
||||
|
||||
def test_empty_queue_raises(puller):
    """Pulling when nothing was pushed raises QueueEmpty."""
    expect_empty = pytest.raises(QueueEmpty)
    with expect_empty:
        puller.pull()
|
||||
|
||||
|
||||
def test_push_non_json_data(pusher):
    """Pushing objects that are not JSON-serializable raises TypeError."""

    class NonJSON:
        pass

    # An arbitrary class instance first, then raw bytes.
    for bad_payload in (NonJSON(), b"binary data"):
        with pytest.raises(TypeError):
            pusher.push(bad_payload)
|
||||
|
||||
|
||||
def test_push_large_data(pusher, puller):
    """A larger nested JSON payload survives the push/pull round trip."""
    nested_items = [{"deeper": True} for _ in range(100)]
    payload = {
        "array": list(range(1000)),
        "nested": {"deep": nested_items},
    }

    pusher.push(payload)

    assert puller.pull() == payload
|
||||
|
||||
|
||||
def test_pull_with_custom_timeout(puller):
    """Test pull() with custom timeout parameter"""
    # A short timeout on an empty queue should fail fast. Elapsed time is
    # measured with a monotonic clock so wall-clock adjustments cannot skew it.
    start_time = time.monotonic()
    with pytest.raises(QueueEmpty, match="No data available after 10ms timeout"):
        puller.pull(timeout_ms=10)
    elapsed = time.monotonic() - start_time
    assert elapsed < 0.05  # 10ms timeout plus scheduling slack

    # Test with longer timeout while data is coming
    test_data = {"test": "timeout"}
    thread = run_pusher_in_thread(test_data, delay=0.2)

    # Should get data within 500ms timeout
    start_time = time.monotonic()
    received = puller.pull(timeout_ms=500)
    elapsed = time.monotonic() - start_time

    assert received == test_data
    assert elapsed < 0.5  # Should complete before timeout
    thread.join()
|
||||
|
||||
|
||||
def test_pull_timeout_none(puller):
    """Test that pull(timeout_ms=None) uses default timeout"""
    # Monotonic clock: immune to wall-clock adjustments during the wait.
    start_time = time.monotonic()
    with pytest.raises(QueueEmpty):
        puller.pull(timeout_ms=None)  # Should use default 1000ms

    elapsed = time.monotonic() - start_time
    assert abs(elapsed - 1.0) < 0.1  # Approximately default timeout
|
||||
|
||||
|
||||
def test_pull_timeout_zero(puller):
    """Test that pull(timeout_ms=0) returns immediately"""
    # Monotonic clock: immune to wall-clock adjustments during the call.
    start_time = time.monotonic()
    with pytest.raises(QueueEmpty, match="No data available after 0ms timeout"):
        puller.pull(timeout_ms=0)
    elapsed = time.monotonic() - start_time
    assert elapsed < 0.01  # Should be nearly instantaneous
|
||||
|
||||
|
||||
def test_mixed_timeout_usage(pusher, puller):
    """Test mixing different timeout values in successive calls"""
    # All durations use a monotonic clock so wall-clock jumps cannot skew them.

    # First with default timeout (should wait)
    start_time = time.monotonic()
    with pytest.raises(QueueEmpty):
        puller.pull()
    assert abs((time.monotonic() - start_time) - 1.0) < 0.1

    # Then with short timeout
    start_time = time.monotonic()
    with pytest.raises(QueueEmpty):
        puller.pull(timeout_ms=100)
    assert abs((time.monotonic() - start_time) - 0.1) < 0.05

    # Then push some data
    pusher.push({"value": 42})

    # Should get it even with very long timeout
    start_time = time.monotonic()
    assert puller.pull(timeout_ms=5000) == {"value": 42}
    assert (time.monotonic() - start_time) < 0.1  # Should return immediately
|
||||
|
||||
|
||||
def test_timeout_restoration_after_error(puller):
    """Test that default timeout is restored after error"""
    original_timeout = puller.default_timeout_ms

    # First verify default behavior
    with pytest.raises(QueueEmpty):
        puller.pull()

    # Change timeout temporarily
    with pytest.raises(QueueEmpty):
        puller.pull(timeout_ms=100)

    # Verify default timeout is restored. A monotonic clock is used so that
    # wall-clock adjustments cannot distort the measured wait.
    start_time = time.monotonic()
    with pytest.raises(QueueEmpty):
        puller.pull()
    elapsed = time.monotonic() - start_time
    assert abs(elapsed - (original_timeout / 1000)) < 0.1
|
||||
|
||||
|
||||
def test_concurrent_access(puller):
    """The puller receives a message pushed from another thread."""
    payload = {"message": "hello"}
    worker = run_pusher_in_thread(payload)

    result = puller.pull()
    worker.join()

    assert result == payload
|
||||
|
||||
|
||||
def test_multiple_messages(pusher, puller):
    """Several messages pushed in a row are pulled back in the same order."""
    contents = ("first", "second", "third")
    messages = [
        {"id": idx, "content": text} for idx, text in enumerate(contents, start=1)
    ]

    for outgoing in messages:
        pusher.push(outgoing)

    for wanted in messages:
        got = puller.pull()
        assert got == wanted
|
||||
|
||||
|
||||
def test_rapid_fire_messages(puller):
    """All 100 messages sent in quick succession by a thread are received."""
    payload = {"count": 0}
    worker = run_pusher_in_thread(payload, count=100, delay=0.01)

    received_count = 0
    while received_count < 100:
        try:
            puller.pull()
        except QueueEmpty:
            # Stop only once the sender has finished; otherwise keep polling.
            if not worker.is_alive():
                break
        else:
            received_count += 1

    worker.join()
    assert received_count == 100
|
Loading…
Reference in New Issue