mirror of https://github.com/inclusionAI/AReaL
fix: remove the file lock of NFS-based name resolve and add testcases (#27)
* . * add nfs name resolve tests
This commit is contained in:
parent
62e51c3109
commit
5da0c62145
|
@ -271,24 +271,8 @@ class MemoryNameRecordRepository(NameRecordRepository):
|
|||
|
||||
class NfsNameRecordRepository(NameRecordRepository):
|
||||
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
|
||||
LOCK_FILE = f"{cluster_spec.fileroot}/name_resolve/LOCK"
|
||||
os.makedirs(RECORD_ROOT, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def locked(fn: Callable) -> Callable:
|
||||
def fn_(*args, **kwargs):
|
||||
import fcntl
|
||||
|
||||
with open(NfsNameRecordRepository.LOCK_FILE, "w") as fd:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
try:
|
||||
res = fn(*args, **kwargs)
|
||||
finally:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
return res
|
||||
|
||||
return fn_
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__to_delete = set()
|
||||
|
||||
|
@ -300,7 +284,6 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
def __file_path(name):
|
||||
return os.path.join(NfsNameRecordRepository.__dir_path(name), "ENTRY")
|
||||
|
||||
@locked
|
||||
def add(
|
||||
self,
|
||||
name,
|
||||
|
@ -309,6 +292,8 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
keepalive_ttl=None,
|
||||
replace=False,
|
||||
):
|
||||
if not name:
|
||||
raise ValueError("Name cannot be empty")
|
||||
path = self.__file_path(name)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
if os.path.isfile(path) and not replace:
|
||||
|
@ -320,7 +305,6 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
if delete_on_exit:
|
||||
self.__to_delete.add(name)
|
||||
|
||||
@locked
|
||||
def delete(self, name):
|
||||
path = self.__file_path(name)
|
||||
if not os.path.isfile(path):
|
||||
|
@ -344,7 +328,6 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
else:
|
||||
logger.info("No such name resolve path: %s", dir_path)
|
||||
|
||||
@locked
|
||||
def get(self, name):
|
||||
path = self.__file_path(name)
|
||||
if not os.path.isfile(path):
|
||||
|
@ -377,7 +360,11 @@ class NfsNameRecordRepository(NameRecordRepository):
|
|||
rs = []
|
||||
if os.path.isdir(dir_path):
|
||||
for item in os.listdir(dir_path):
|
||||
rs.append(os.path.join(name_root, item))
|
||||
try:
|
||||
self.get(os.path.join(name_root, item))
|
||||
rs.append(os.path.join(name_root, item))
|
||||
except NameEntryNotFoundError:
|
||||
pass
|
||||
rs.sort()
|
||||
return rs
|
||||
|
||||
|
|
|
@ -0,0 +1,272 @@
|
|||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from realhf.base.name_resolve import (
|
||||
NameEntryExistsError,
|
||||
NameEntryNotFoundError,
|
||||
NfsNameRecordRepository,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_nfs_root():
|
||||
# Create a temporary directory to simulate NFS root
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_root = NfsNameRecordRepository.RECORD_ROOT
|
||||
NfsNameRecordRepository.RECORD_ROOT = temp_dir
|
||||
yield temp_dir
|
||||
# Cleanup
|
||||
NfsNameRecordRepository.RECORD_ROOT = original_root
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nfs_repo(temp_nfs_root):
|
||||
repo = NfsNameRecordRepository()
|
||||
yield repo
|
||||
repo.reset()
|
||||
|
||||
|
||||
def test_add_basic(nfs_repo):
|
||||
# Test basic add functionality
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
assert nfs_repo.get("test_key") == "test_value"
|
||||
|
||||
# Verify file was created
|
||||
assert os.path.isfile(
|
||||
os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test_key/ENTRY")
|
||||
)
|
||||
|
||||
# Non-string value
|
||||
nfs_repo.add(
|
||||
"test_key", 123, replace=True
|
||||
) # Should fail if non-string values aren't converted
|
||||
assert nfs_repo.get("test_key") == str(123)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
nfs_repo.add("", "value")
|
||||
|
||||
|
||||
def test_add_with_replace(nfs_repo):
|
||||
# Test add with replace=False (should raise)
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
with pytest.raises(NameEntryExistsError):
|
||||
nfs_repo.add("test_key", "new_value", replace=False)
|
||||
|
||||
# Test add with replace=True
|
||||
nfs_repo.add("test_key", "new_value", replace=True)
|
||||
assert nfs_repo.get("test_key") == "new_value"
|
||||
|
||||
|
||||
def test_add_delete_on_exit(nfs_repo):
|
||||
# Test delete_on_exit flag
|
||||
nfs_repo.add("test_key1", "value1", delete_on_exit=True)
|
||||
nfs_repo.add("test_key2", "value2", delete_on_exit=False)
|
||||
|
||||
assert "test_key1" in nfs_repo._NfsNameRecordRepository__to_delete
|
||||
assert "test_key2" not in nfs_repo._NfsNameRecordRepository__to_delete
|
||||
|
||||
|
||||
def test_delete(nfs_repo):
|
||||
# Test deleting existing key
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
nfs_repo.delete("test_key")
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
nfs_repo.get("test_key")
|
||||
|
||||
# Test deleting non-existent key
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
nfs_repo.delete("non_existent_key")
|
||||
|
||||
|
||||
def test_delete_cleanup_dirs(nfs_repo):
|
||||
# Test that empty parent directories are cleaned up
|
||||
nfs_repo.add("test/path/key", "value")
|
||||
assert os.path.isdir(os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test/path"))
|
||||
|
||||
nfs_repo.delete("test/path/key")
|
||||
# Should clean up empty parent directories
|
||||
assert not os.path.exists(
|
||||
os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test/path")
|
||||
)
|
||||
assert not os.path.exists(os.path.join(NfsNameRecordRepository.RECORD_ROOT, "test"))
|
||||
|
||||
|
||||
def test_clear_subtree(nfs_repo):
|
||||
# Test clearing a subtree
|
||||
nfs_repo.add("test_root/key1", "value1")
|
||||
nfs_repo.add("test_root/key2", "value2")
|
||||
nfs_repo.add("test_root/sub/key3", "value3")
|
||||
nfs_repo.add("other_root/key", "value")
|
||||
|
||||
nfs_repo.clear_subtree("test_root")
|
||||
|
||||
# Verify subtree is gone
|
||||
assert nfs_repo.get_subtree("test_root") == []
|
||||
assert nfs_repo.find_subtree("test_root") == []
|
||||
|
||||
# Verify other tree is intact
|
||||
assert nfs_repo.get("other_root/key") == "value"
|
||||
|
||||
|
||||
def test_get(nfs_repo):
|
||||
# Test getting existing key
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
assert nfs_repo.get("test_key") == "test_value"
|
||||
|
||||
# Test getting non-existent key
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
nfs_repo.get("non_existent_key")
|
||||
|
||||
|
||||
def test_get_stale_file_handle_recovery(nfs_repo):
|
||||
# Test handling of stale file handles
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
|
||||
# Mock os.open to raise OSError with errno 116 (ESTALE) first few times
|
||||
original_open = open
|
||||
|
||||
def mock_open(*args, **kwargs):
|
||||
mock_open.call_count += 1
|
||||
if mock_open.call_count <= 3: # Fail first 3 times
|
||||
raise OSError(116, "Stale file handle")
|
||||
return original_open(*args, **kwargs)
|
||||
|
||||
mock_open.call_count = 0
|
||||
|
||||
with patch("builtins.open", mock_open):
|
||||
assert nfs_repo.get("test_key") == "test_value"
|
||||
assert mock_open.call_count == 4
|
||||
|
||||
|
||||
def test_get_subtree(nfs_repo):
|
||||
# Test getting subtree values
|
||||
nfs_repo.add("test_root/key1", "value1")
|
||||
nfs_repo.add("test_root/key2", "value2")
|
||||
nfs_repo.add("test_root/sub/key3", "value3")
|
||||
|
||||
values = nfs_repo.get_subtree("test_root")
|
||||
assert set(values) == {"value1", "value2"}
|
||||
|
||||
|
||||
def test_find_subtree(nfs_repo):
|
||||
# Test finding subtree keys
|
||||
nfs_repo.add("test_root/key1", "value1")
|
||||
nfs_repo.add("test_root/key2", "value2")
|
||||
nfs_repo.add("test_root/sub/key3", "value3")
|
||||
|
||||
keys = nfs_repo.find_subtree("test_root")
|
||||
assert set(keys) == {"test_root/key1", "test_root/key2"}
|
||||
assert keys == sorted(keys) # Should be sorted
|
||||
|
||||
|
||||
def test_reset(nfs_repo):
|
||||
# Test reset functionality
|
||||
nfs_repo.add("test_key1", "value1", delete_on_exit=True)
|
||||
nfs_repo.add("test_key2", "value2", delete_on_exit=False)
|
||||
|
||||
nfs_repo.reset()
|
||||
|
||||
# Only test_key1 should be deleted
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
nfs_repo.get("test_key1")
|
||||
assert nfs_repo.get("test_key2") == "value2"
|
||||
|
||||
|
||||
def test_context_manager(nfs_repo):
|
||||
# Test context manager functionality
|
||||
with NfsNameRecordRepository() as repo:
|
||||
repo.add("test_key", "test_value", delete_on_exit=True)
|
||||
assert repo.get("test_key") == "test_value"
|
||||
|
||||
# Key should be deleted after context exits
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
nfs_repo.get("test_key")
|
||||
|
||||
|
||||
def test_destructor(nfs_repo):
|
||||
# Test destructor functionality
|
||||
repo = NfsNameRecordRepository()
|
||||
repo.add("test_key", "test_value", delete_on_exit=True)
|
||||
|
||||
# Simulate object destruction
|
||||
repo.__del__()
|
||||
|
||||
# Key should be deleted
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
nfs_repo.get("test_key")
|
||||
|
||||
|
||||
def test_add_subentry(nfs_repo):
|
||||
# Test subentry creation
|
||||
sub_name = nfs_repo.add_subentry("test_root", "sub_value")
|
||||
assert sub_name.startswith("test_root/")
|
||||
assert nfs_repo.get(sub_name) == "sub_value"
|
||||
|
||||
|
||||
def test_wait(nfs_repo):
|
||||
# Test wait functionality
|
||||
import threading
|
||||
|
||||
def delayed_add():
|
||||
time.sleep(0.1)
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
|
||||
job = threading.Thread(target=delayed_add, daemon=True)
|
||||
job.start()
|
||||
|
||||
# Should return once key is added
|
||||
assert nfs_repo.wait("test_key", timeout=2) == "test_value"
|
||||
job.join()
|
||||
|
||||
# Test timeout
|
||||
with pytest.raises(TimeoutError):
|
||||
nfs_repo.wait("non_existent_key", timeout=0.1)
|
||||
|
||||
|
||||
def test_watch_names(nfs_repo):
|
||||
# Test watch functionality
|
||||
callback_called = False
|
||||
|
||||
def callback():
|
||||
nonlocal callback_called
|
||||
callback_called = True
|
||||
|
||||
nfs_repo.add("test_key", "test_value")
|
||||
nfs_repo.watch_names("test_key", callback)
|
||||
|
||||
# Delete the key to trigger callback
|
||||
nfs_repo.delete("test_key")
|
||||
|
||||
# Wait for callback
|
||||
time.sleep(5) # Give watcher thread time to execute
|
||||
assert callback_called
|
||||
|
||||
|
||||
def test_concurrent_access(nfs_repo):
|
||||
# Test concurrent access to the same key
|
||||
import threading
|
||||
|
||||
nfs_repo.add("test_key", "initial_value")
|
||||
|
||||
def modify_value():
|
||||
for i in range(10):
|
||||
current = nfs_repo.get("test_key")
|
||||
nfs_repo.add("test_key", f"modified_{i}", replace=True)
|
||||
time.sleep(0.01)
|
||||
|
||||
threads = [threading.Thread(target=modify_value) for _ in range(5)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Final value should be one of the modified values
|
||||
final_value = nfs_repo.get("test_key")
|
||||
assert final_value.startswith("modified_")
|
Loading…
Reference in New Issue