# Mirror of https://github.com/inclusionAI/AReaL (690 lines, 22 KiB, Python).
import os
|
|
import shutil
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from realhf.base.name_resolve import (
|
|
Etcd3NameRecordRepository,
|
|
NameEntryExistsError,
|
|
NameEntryNotFoundError,
|
|
NfsNameRecordRepository,
|
|
)
|
|
|
|
# Name-resolve backends exercised by every parameterized test. Each entry is a
# ``(backend_type, constructor_kwargs)`` pair consumed by the ``name_resolve``
# fixture.
BACKENDS = [(backend, {}) for backend in ("memory", "nfs", "ray")]

# The etcd3 backend needs a live server, so it is only added when
# REAL_ETCD_ADDR is set (non-empty) to a "host:port" address.
_etcd_addr = os.getenv("REAL_ETCD_ADDR")
if _etcd_addr:
    _etcd_fields = _etcd_addr.split(":")
    BACKENDS.append(
        ("etcd3", {"host": _etcd_fields[0], "port": int(_etcd_fields[1])})
    )
|
|
|
|
|
|
@pytest.fixture(params=BACKENDS, ids=[b[0] for b in BACKENDS])
def name_resolve(request):
    """Provide a name resolve repository for each backend type in ``BACKENDS``.

    Teardown always calls ``repo.reset()``. For the NFS backend the class-level
    ``NfsNameRecordRepository.RECORD_ROOT`` is pointed at a fresh temporary
    directory for the duration of the test. The original root is restored and
    the directory removed even if construction or ``reset()`` raises.

    Raises:
        ValueError: if ``BACKENDS`` contains an unknown backend type.
    """
    backend_type, kwargs = request.param

    if backend_type == "nfs":
        temp_dir = tempfile.mkdtemp()
        original_root = NfsNameRecordRepository.RECORD_ROOT
        NfsNameRecordRepository.RECORD_ROOT = temp_dir
        try:
            repo = NfsNameRecordRepository()
            try:
                yield repo
            finally:
                repo.reset()
        finally:
            # Undo the class-level patch first so later tests never see the
            # temporary root. Then clean up best-effort, so that a reset()
            # error is not masked by a rmtree error.
            NfsNameRecordRepository.RECORD_ROOT = original_root
            shutil.rmtree(temp_dir, ignore_errors=True)
        return

    if backend_type == "memory":
        from realhf.base.name_resolve import MemoryNameRecordRepository

        repo = MemoryNameRecordRepository()
    elif backend_type == "etcd3":
        repo = Etcd3NameRecordRepository(**kwargs)
    elif backend_type == "ray":
        from realhf.base.name_resolve import RayNameResolveRepository

        repo = RayNameResolveRepository(**kwargs)
    else:
        # Previously an unknown type yielded nothing, which surfaces as a
        # confusing pytest fixture error instead of a clear message.
        raise ValueError(f"Unknown name resolve backend: {backend_type!r}")

    try:
        yield repo
    finally:
        repo.reset()
|
|
|
|
|
|
def test_basic_add_get(name_resolve):
    """A stored value reads back; non-string values come back as their str()."""
    cases = [
        # (key, value passed to add, extra add kwargs, expected get result)
        ("test_key", "test_value", {}, "test_value"),
        ("test_key_int", 123, {"replace": True}, "123"),
    ]
    for key, stored, extra_kwargs, expected in cases:
        name_resolve.add(key, stored, **extra_kwargs)
        assert name_resolve.get(key) == expected
|
|
|
|
|
|
def test_add_with_replace(name_resolve):
    """Re-adding an existing key fails unless ``replace=True`` is passed."""
    key = "test_key"
    name_resolve.add(key, "initial_value")

    # A second add without replace is rejected...
    with pytest.raises(NameEntryExistsError):
        name_resolve.add(key, "new_value", replace=False)

    # ...while replace=True overwrites the stored value.
    name_resolve.add(key, "new_value", replace=True)
    assert name_resolve.get(key) == "new_value"
|
|
|
|
|
|
def test_delete(name_resolve):
    """A deleted key is unreadable, and deleting a missing key raises."""
    name_resolve.add("test_key", "test_value")
    name_resolve.delete("test_key")

    for operation, key in (
        (name_resolve.get, "test_key"),  # the key removed just above
        (name_resolve.delete, "non_existent_key"),  # a key that never existed
    ):
        with pytest.raises(NameEntryNotFoundError):
            operation(key)
|
|
|
|
|
|
def test_clear_subtree(name_resolve):
    """clear_subtree removes every key under a prefix and nothing else."""
    entries = {
        "test_root/key1": "value1",
        "test_root/key2": "value2",
        "test_root/sub/key3": "value3",
        "other_root/key": "value",
    }
    for key, value in entries.items():
        name_resolve.add(key, value)

    name_resolve.clear_subtree("test_root")

    # Nothing remains under the cleared prefix...
    assert name_resolve.get_subtree("test_root") == []
    assert name_resolve.find_subtree("test_root") == []
    # ...and the sibling tree is untouched.
    assert name_resolve.get("other_root/key") == "value"
|
|
|
|
|
|
def test_get_subtree(name_resolve):
    """get_subtree returns the values of all keys under a prefix, at any depth."""
    expected_values = set()
    for suffix, value in (
        ("key1", "value1"),
        ("key2", "value2"),
        ("sub/key3", "value3"),
    ):
        name_resolve.add(f"test_root/{suffix}", value)
        expected_values.add(value)

    assert set(name_resolve.get_subtree("test_root")) == expected_values
|
|
|
|
|
|
def test_find_subtree(name_resolve):
    """find_subtree returns every key under a prefix, in sorted order."""
    expected_keys = ["test_root/key1", "test_root/key2", "test_root/sub/key3"]
    for index, key in enumerate(expected_keys, start=1):
        name_resolve.add(key, f"value{index}")

    found = name_resolve.find_subtree("test_root")
    assert set(found) == set(expected_keys)
    assert found == sorted(found)  # results must come back sorted
|
|
|
|
|
|
def test_add_subentry(name_resolve):
    """add_subentry stores the value under ``<root>/<8-char id>`` and returns it."""
    generated = name_resolve.add_subentry("test_root", "sub_value")

    assert generated.startswith("test_root/")
    generated_id = generated.rpartition("/")[2]
    assert len(generated_id) == 8  # the generated unique-id component
    assert name_resolve.get(generated) == "sub_value"
|
|
|
|
|
|
def test_wait(name_resolve):
    """wait() blocks until a key appears and raises TimeoutError if it never does."""

    def _publish_later():
        time.sleep(0.1)
        name_resolve.add("test_key", "test_value")

    publisher = threading.Thread(target=_publish_later, daemon=True)
    publisher.start()

    # The key is written ~0.1s from now, well inside the 2s budget.
    observed = name_resolve.wait("test_key", timeout=2)
    assert observed == "test_value"
    publisher.join()

    # A key nobody writes must time out.
    with pytest.raises(TimeoutError):
        name_resolve.wait("non_existent_key", timeout=0.1)
|
|
|
|
|
|
def test_watch_names(name_resolve):
    """watch_names invokes the callback once the watched key is deleted."""
    # The callback runs on the watcher's thread; an Event hands the signal back
    # to the test thread and lets us stop waiting as soon as it fires.
    callback_fired = threading.Event()

    name_resolve.add("test_key", "test_value")
    name_resolve.watch_names("test_key", callback_fired.set, poll_frequency=0.1)

    # Delete the key to trigger callback
    time.sleep(0.1)  # Ensure watcher is ready
    name_resolve.delete("test_key")

    # Same 2s upper bound as a fixed sleep, but returns as soon as the
    # callback has run.
    assert callback_fired.wait(timeout=2), "watch_names callback was not called"
|
|
|
|
|
|
def test_reset(name_resolve):
    """reset() removes only keys added with ``delete_on_exit=True``."""
    for key, value, delete_flag in (
        ("test_key1", "value1", True),
        ("test_key_no_delete", "value2", False),
    ):
        name_resolve.add(key, value, delete_on_exit=delete_flag)

    name_resolve.reset()

    with pytest.raises(NameEntryNotFoundError):
        name_resolve.get("test_key1")  # delete_on_exit=True -> removed
    assert name_resolve.get("test_key_no_delete") == "value2"  # kept

    # reset() leaves this key behind, so remove it by hand.
    name_resolve.delete("test_key_no_delete")
|
|
|
|
|
|
def test_context_manager(name_resolve):
    """Leaving a repository's ``with`` block cleans up its delete_on_exit keys."""
    # NOTE(review): the second repository is built with no constructor
    # arguments. For backends configured via the fixture's kwargs (e.g. etcd3
    # host/port), or ones whose store is per-instance, it may not share state
    # with ``name_resolve``, which would make the final check trivially true.
    # Confirm against the repository implementations.
    repo_cls = type(name_resolve)
    with repo_cls() as scoped_repo:
        scoped_repo.add("test_key", "test_value", delete_on_exit=True)
        assert scoped_repo.get("test_key") == "test_value"

    # After the context exits, the key must no longer be visible.
    with pytest.raises(NameEntryNotFoundError):
        name_resolve.get("test_key")
|
|
|
|
|
|
def test_concurrent_access(name_resolve):
    """Several threads interleaving get/replace on one key must not error."""
    name_resolve.add("test_key", "initial_value")
    # Exceptions raised inside worker threads never reach pytest on their own
    # (the test would still pass), so collect them and assert from the main
    # thread.
    errors = []

    def modify_value():
        try:
            for i in range(5):
                # The result is unused; the call exercises get() while other
                # threads are replacing the value.
                name_resolve.get("test_key")
                name_resolve.add(
                    "test_key", f"modified_{threading.get_ident()}_{i}", replace=True
                )
                time.sleep(0.01)
        except Exception as exc:
            errors.append(exc)

    threads = [threading.Thread(target=modify_value) for _ in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert not errors, f"worker threads raised: {errors!r}"

    # Final value should be one of the modified values
    final_value = name_resolve.get("test_key")
    assert "modified_" in final_value
|
|
|
|
|
|
def test_path_normalization(name_resolve):
    """Test handling of different path formats.

    The assertions pin three behaviours:

    - a trailing slash is ignored on write and on read;
    - a key containing a doubled slash round-trips;
    - ``./test_path`` addresses the same entry as ``test_path``.
    """
    # Trailing slash: written as "test_path/", readable with or without it.
    name_resolve.add("test_path/", "value1")
    assert name_resolve.get("test_path") == "value1"
    assert name_resolve.get("test_path/") == "value1"

    # Double slash: the same spelling reads back the stored value.
    name_resolve.add("test//path", "value2")
    assert name_resolve.get("test//path") == "value2"

    # Relative form: expected to collide with the existing "test_path" entry,
    # so it can only be written with replace=True.
    with pytest.raises(NameEntryExistsError):
        name_resolve.add("./test_path", "value3")
    name_resolve.add("./test_path", "value3", replace=True)
    assert name_resolve.get("./test_path") == "value3"
|
|
|
|
|
|
def test_add_with_invalid_inputs(name_resolve):
    """None/empty names are rejected; a None value is stored as "None"."""
    # The exact exception type differs between backends, so accept any.
    for bad_name in (None, ""):
        with pytest.raises(Exception):
            name_resolve.add(bad_name, "value")

    name_resolve.add("test_key", None)
    assert name_resolve.get("test_key") == "None"
|
|
|
|
|
|
def test_long_paths_and_values(name_resolve):
    """Very deep key paths and very large values round-trip unchanged."""
    # 100 nested "a" components followed by the leaf "key".
    deep_key = "/".join(["a"] * 100 + ["key"])
    big_value = "x" * 10_000

    name_resolve.add(deep_key, big_value)
    assert name_resolve.get(deep_key) == big_value
|
|
|
|
|
|
def test_special_characters(name_resolve):
    """Special characters: best-effort in names, required to work in values."""
    special_chars = "!@#$%^&*()_+-=[]{}|;:'\",.<>?`~ "

    # Names: not every backend is expected to accept every character, so
    # failures here are only reported, not treated as test failures.
    for char in special_chars:
        try:
            name = f"test{char}key"
            name_resolve.add(name, "value")
            assert name_resolve.get(name) == "value"
            name_resolve.delete(name)
        except Exception as e:
            print(f"Failed with character '{char}': {e}")

    # Values: use a plain, index-based key. This part then exercises only
    # value handling and cannot fail because a backend rejects the character
    # in a *name* (which the loop above explicitly tolerates).
    for index, char in enumerate(special_chars):
        key = f"special_value_key_{index}"
        value = f"test{char}value"
        name_resolve.add(key, value)
        assert name_resolve.get(key) == value
|
|
|
|
|
|
def test_unicode_support(name_resolve):
    """Non-ASCII names and values round-trip unchanged."""
    pairs = {"测试/键": "价值"}
    for unicode_key, unicode_value in pairs.items():
        name_resolve.add(unicode_key, unicode_value)
        assert name_resolve.get(unicode_key) == unicode_value
|
|
|
|
|
|
def test_stress_concurrent_add_get_delete(name_resolve):
    """Stress test with many concurrent add/get/delete cycles on distinct keys."""
    from concurrent.futures import ThreadPoolExecutor

    num_threads = 20
    ops_per_thread = 50

    # Shared by all workers. ``results[k] += 1`` is a read-modify-write and is
    # not atomic across threads, so every update goes through the lock.
    results = {
        "success": 0,
        "failures": 0,
    }
    results_lock = threading.Lock()

    def record(outcome):
        with results_lock:
            results[outcome] += 1

    def worker(thread_id):
        try:
            for i in range(ops_per_thread):
                key = f"concurrent_key_{thread_id}_{i}"
                value = f"value_{thread_id}_{i}"

                name_resolve.add(key, value)
                assert name_resolve.get(key) == value
                name_resolve.delete(key)

                # A completed cycle leaves the key gone.
                try:
                    name_resolve.get(key)
                    record("failures")
                except NameEntryNotFoundError:
                    record("success")
        except Exception as e:
            # Any unexpected error aborts this worker and counts once.
            print(f"Thread {thread_id} failed: {e}")
            record("failures")

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(worker, i) for i in range(num_threads)]
        for future in futures:
            future.result()

    # Allow up to 10% failure rate
    assert results["failures"] <= results["success"] * 0.1
|
|
|
|
|
|
def test_add_subentry_uniqueness(name_resolve):
    """Repeated add_subentry calls under one root never reuse a name."""
    num_entries = 100
    generated_names = {
        name_resolve.add_subentry("test_root", "value") for _ in range(num_entries)
    }
    # Any collision would shrink the set below the number of calls.
    assert len(generated_names) == num_entries
|
|
|
|
|
|
def test_wait_with_concurrent_delete(name_resolve):
    """wait() returns the value of a key that is deleted shortly afterwards."""

    def _publish_then_retract():
        time.sleep(0.1)
        name_resolve.add("test_wait_key", "test_value")
        time.sleep(1.0)  # keep the key visible for about a second
        name_resolve.delete("test_wait_key")

    background = threading.Thread(target=_publish_then_retract, daemon=True)
    background.start()

    # Fast polling and a generous timeout so wait() sees the short-lived key.
    observed = name_resolve.wait("test_wait_key", timeout=3.0, poll_frequency=0.05)
    assert observed == "test_value"

    # Let the background thread finish its delete, plus a short grace period.
    background.join()
    time.sleep(0.5)

    with pytest.raises(NameEntryNotFoundError):
        name_resolve.get("test_wait_key")
|
|
|
|
|
|
def test_wait_edge_cases(name_resolve):
    """wait() on a missing key raises TimeoutError for zero, negative and tiny timeouts."""
    # timeout=0: no waiting budget at all, so a missing key times out.
    with pytest.raises(TimeoutError):
        name_resolve.wait("nonexistent_key", timeout=0)

    # Negative timeout: must also raise TimeoutError promptly. It is NOT
    # treated like timeout=None, which would block forever on a missing key.
    with pytest.raises(TimeoutError):
        name_resolve.wait("nonexistent_key", timeout=-1, poll_frequency=0.01)

    # A poll interval far smaller than the timeout still ends in TimeoutError.
    with pytest.raises(TimeoutError):
        name_resolve.wait("nonexistent_key", timeout=0.1, poll_frequency=0.001)
|
|
|
|
|
|
def test_watch_names_multiple_keys(name_resolve):
    """Watching several keys fires the callback exactly once in total."""
    fired = [0]  # mutable cell, updated from the watcher thread(s)

    def on_keys_gone():
        fired[0] += 1

    watched = ["watch_key1", "watch_key2"]
    for position, key in enumerate(watched, start=1):
        name_resolve.add(key, f"value{position}")

    name_resolve.watch_names(list(watched), on_keys_gone, poll_frequency=0.1)

    time.sleep(0.2)  # let the watcher start polling

    # Remove the keys one at a time, leaving room for the watcher to react.
    name_resolve.delete("watch_key1")
    time.sleep(0.5)
    name_resolve.delete("watch_key2")
    time.sleep(1.0)

    # The callback should have run exactly once (when the last key is gone).
    assert fired[0] == 1
|
|
|
|
|
|
def test_thread_safety_of_watch_thread_run(name_resolve):
    """_watch_thread_run fires the callback once get() reports the key missing."""
    real_get = name_resolve.get
    call_count = 0

    def flaky_get(name):
        # Odd-numbered calls pass through. Even-numbered calls pretend the key
        # vanished, simulating a concurrent delete.
        nonlocal call_count
        call_count += 1
        if call_count % 2 == 0:
            raise NameEntryNotFoundError(f"Key not found: {name}")
        return real_get(name)

    callback_called = False

    def callback():
        nonlocal callback_called
        callback_called = True

    name_resolve.add("test_key", "test_value")

    # Run the watcher loop synchronously while get() is replaced. The trailing
    # positional arguments (0.1, 1) are presumably poll frequency and a
    # timeout -- TODO confirm against _watch_thread_run's signature.
    with patch.object(name_resolve, "get", side_effect=flaky_get):
        name_resolve._watch_thread_run("test_key", callback, 0.1, 1)

    assert callback_called
|
|
|
|
|
|
def test_keepalive_ttl(name_resolve):
    """A key added with keepalive_ttl expires once lease refreshes fail."""
    # TTL leases are only exercised on the etcd3 backend. Use isinstance on the
    # imported class rather than a substring match on the class name.
    if not isinstance(name_resolve, Etcd3NameRecordRepository):
        pytest.skip("keepalive_ttl test is specific to Etcd3NameRecordRepository")

    # Add a key with short TTL
    name_resolve.add("ttl_key", "ttl_value", keepalive_ttl=2)

    # Within the TTL the key must still be readable.
    time.sleep(1)
    assert name_resolve.get("ttl_key") == "ttl_value"

    # Make every lease refresh fail, then wait past the TTL.
    with patch.object(
        name_resolve._client, "refresh_lease", side_effect=Exception("Refresh failed")
    ):
        time.sleep(3)

    # Key should be gone if TTL is working
    with pytest.raises(NameEntryNotFoundError):
        name_resolve.get("ttl_key")
|
|
|
|
|
|
def test_subentry_with_custom_uuid(name_resolve, monkeypatch):
    """add_subentry names the entry ``<root>/<first 8 chars of the uuid4 id>``."""
    # Return a real UUID object, not a str. The test then holds whether the
    # implementation formats it with str() or .hex: both begin "12345678".
    fixed_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")
    monkeypatch.setattr(uuid, "uuid4", MagicMock(return_value=fixed_uuid))

    sub_name = name_resolve.add_subentry("test_root", "sub_value")

    # Verify the subentry has the expected name
    assert sub_name == "test_root/12345678"
    assert name_resolve.get(sub_name) == "sub_value"
|
|
|
|
|
|
def test_race_condition_in_add(name_resolve):
    """Exactly one of several concurrent ``add(replace=False)`` calls wins."""
    if isinstance(name_resolve, NfsNameRecordRepository):
        pytest.skip("NFS repo cannot tackle race conditions")

    num_threads = 10
    key = "race_condition_key"
    success_count = 0
    failure_count = 0
    # Any other exception would otherwise vanish inside its thread and show up
    # only as an unexplained count mismatch.
    unexpected_errors = []
    # ``count += 1`` on a shared variable is not atomic across threads.
    counter_lock = threading.Lock()

    def add_with_same_key():
        nonlocal success_count, failure_count
        try:
            name_resolve.add(key, f"value_{threading.get_ident()}", replace=False)
        except NameEntryExistsError:
            with counter_lock:
                failure_count += 1
        except Exception as exc:
            unexpected_errors.append(exc)
        else:
            with counter_lock:
                success_count += 1

    # Run concurrent add operations
    threads = [threading.Thread(target=add_with_same_key) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert not unexpected_errors, f"unexpected errors: {unexpected_errors!r}"
    # Verify only one thread succeeded
    assert success_count == 1
    assert failure_count == num_threads - 1

    # The stored value comes from one of the writers.
    assert name_resolve.get(key).startswith("value_")
|
|
|
|
|
|
def test_find_subtree_with_empty_result(name_resolve):
    """find_subtree on an unused prefix returns an empty list, not None."""
    keys = name_resolve.find_subtree("nonexistent_prefix")
    assert keys == []
    assert isinstance(keys, list)
|
|
|
|
|
|
def test_get_subtree_with_empty_result(name_resolve):
    """get_subtree on an unused prefix returns an empty list, not None."""
    values = name_resolve.get_subtree("nonexistent_prefix")
    assert values == []
    assert isinstance(values, list)
|
|
|
|
|
|
def test_clear_subtree_with_nonexistent_prefix(name_resolve):
    """Clearing a prefix that has no keys is a harmless no-op."""
    name_resolve.clear_subtree("nonexistent_prefix")  # must not raise

    # The repository keeps working normally afterwards.
    name_resolve.add("test_key", "test_value")
    assert name_resolve.get("test_key") == "test_value"
|
|
|
|
|
|
def test_nested_subtrees(name_resolve):
    """Subtree queries and clears respect the prefix depth they are given."""
    layout = {
        "root/level1/level2/level3/key1": "value1",
        "root/level1/level2/key2": "value2",
        "root/level1/key3": "value3",
    }
    for key, value in layout.items():
        name_resolve.add(key, value)

    # get_subtree at two different depths.
    assert set(name_resolve.get_subtree("root")) == set(layout.values())
    assert set(name_resolve.get_subtree("root/level1/level2")) == {"value1", "value2"}

    # find_subtree from the intermediate level sees all three keys.
    assert set(name_resolve.find_subtree("root/level1")) == set(layout)

    name_resolve.clear_subtree("root/level1/level2")

    # Only keys under the cleared prefix are gone.
    for removed in ("root/level1/level2/level3/key1", "root/level1/level2/key2"):
        with pytest.raises(NameEntryNotFoundError):
            name_resolve.get(removed)
    assert name_resolve.get("root/level1/key3") == "value3"
|
|
|
|
|
|
def test_corner_case_get_same_as_prefix(name_resolve):
    """A key that is also the prefix of another key works as both."""
    name_resolve.add("prefix", "parent_value")
    name_resolve.add("prefix/child", "child_value")

    expected = {"prefix": "parent_value", "prefix/child": "child_value"}

    # Each key is individually retrievable.
    for key, value in expected.items():
        assert name_resolve.get(key) == value

    # Subtree queries rooted at "prefix" include the "prefix" key itself.
    assert set(name_resolve.get_subtree("prefix")) == set(expected.values())
    assert set(name_resolve.find_subtree("prefix")) == set(expected)
|
|
|
|
|
|
@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is None, reason="ETCD3 not configured")
def test_etcd3_specific_features(name_resolve):
    """Check an etcd3 key added with ``keepalive_ttl`` over time."""
    if not isinstance(name_resolve, Etcd3NameRecordRepository):
        pytest.skip("ETCD3 specific test")

    name_resolve.add("test_key", "test_value", keepalive_ttl=2)

    # One second in (inside the 2s TTL), read the raw bytes straight from the
    # underlying client; the value must still be there.
    time.sleep(1)
    raw_value, _ = name_resolve._client.get("test_key")
    assert raw_value.decode("utf-8") == "test_value"

    # NOTE(review): the test expects the key to be gone ~3s after it was added
    # (past the 2s TTL), even though a keepalive was requested. This depends
    # on the repository's keepalive/lease behaviour -- confirm it is intended.
    time.sleep(2)
    with pytest.raises(NameEntryNotFoundError):
        name_resolve.get("test_key")
|
|
|
|
|
|
def test_nfs_specific_features(name_resolve):
    """NFS get() retries reads that fail with a stale-file-handle OSError.

    The ``isinstance`` check below scopes this to the NFS backend. The former
    ``skipif(REAL_ETCD_ADDR is not None)`` was dropped: the NFS backend is
    parameterized whether or not etcd is configured, so that skip only lost
    coverage.
    """
    if not isinstance(name_resolve, NfsNameRecordRepository):
        pytest.skip("NFS specific test")

    name_resolve.add("test_key", "test_value")

    original_open = open
    call_count = 0

    def mock_open(*args, **kwargs):
        # Fail the first three opens with errno 116 (ESTALE on Linux), then
        # delegate to the real open().
        nonlocal call_count
        call_count += 1
        if call_count <= 3:
            raise OSError(116, "Stale file handle")
        return original_open(*args, **kwargs)

    with patch("builtins.open", mock_open):
        assert name_resolve.get("test_key") == "test_value"
    # Three failed attempts plus the successful one.
    assert call_count == 4
|