mirror of https://github.com/inclusionAI/AReaL
PullRequest: 58 Support ETCD3 name resolving repo
Merge branch fw/etcd of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/58
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
parent 6ccbb01ca8
commit 9f77f96580
@@ -163,6 +163,7 @@ def main_start(args, recover_count: int = 0):
        REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
        REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
        FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
+       REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
    )
    for k, v in BASE_ENVIRONS.items():
        os.environ[k] = v
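For reference, the new REAL_ETCD_ADDR entry forwards whatever address is set in the launcher's environment and otherwise falls back to a local etcd endpoint. A minimal sketch (not part of the diff) of how the default resolves:

import os

# Falls back to a local etcd server on the standard client port 2379 when the
# variable is unset in the launcher's environment.
etcd_addr = os.getenv("REAL_ETCD_ADDR", "localhost:2379")
host, port = etcd_addr.split(":")
print(host, int(port))  # -> localhost 2379 by default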
@@ -62,7 +62,7 @@ def reveal_pg_identity(expr_name, trial_name, worker_index):
    master_group_name = names.distributed_peer(
        expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME
    )
-   name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=30)
+   name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=300)


def isolate_cuda_device(
@@ -4,18 +4,21 @@
# Implements a simple name resolving service, which can be considered as a distributed key-value dict.
import dataclasses
import getpass
import os
import queue
import random
import shutil
import socket
import threading
import time
import uuid
from typing import Callable, List, Optional

-from realhf.base import logging, security, timeutil
+try:
+    import etcd3
+except Exception:
+    etcd3 = None
+
+from realhf.base import cluster, logging, security, timeutil
from realhf.base.cluster import spec as cluster_spec

logger = logging.getLogger("name-resolve")
@@ -532,6 +535,353 @@ class RedisNameRecordRepository(NameRecordRepository):
        print("Testonly: dropped key:", name)


class Etcd3NameRecordRepository(NameRecordRepository):
    """Implements a name record repository using etcd3 as the backend storage.

    This implementation provides distributed key-value storage with support for
    TTL-based expiration, atomic operations, and key watching functionality.
    """

    # Default configuration
    host, port = os.getenv("REAL_ETCD_ADDR", "localhost:2379").split(":")
    ETCD_HOST = host
    ETCD_PORT = int(port)
    ETCD_USER = None
    ETCD_PASSWORD = None
    KEEPALIVE_POLL_FREQUENCY = 1

    @dataclasses.dataclass
    class _Entry:
        value: str
        lease_id: Optional[int] = None
        keepalive_ttl: Optional[int] = None
        keeper: Optional[timeutil.FrequencyControl] = None

    def __init__(self, host=None, port=None, user=None, password=None, **kwargs):
        """Initialize the etcd3 name record repository.

        Args:
            host: etcd server host (defaults to ETCD_HOST)
            port: etcd server port (defaults to ETCD_PORT)
            user: etcd username for authentication (defaults to ETCD_USER)
            password: etcd password for authentication (defaults to ETCD_PASSWORD)
            **kwargs: Additional configuration parameters
        """
        super().__init__()
        self._lock = threading.Lock()

        # Set connection parameters
        self._host = host or self.ETCD_HOST
        self._port = port or self.ETCD_PORT
        self._user = user or self.ETCD_USER
        self._password = password or self.ETCD_PASSWORD

        # Connect to etcd
        self._client = etcd3.client(
            host=self._host, port=self._port, user=self._user, password=self._password
        )

        # Keep track of entries for cleanup and keepalive
        self._entries = {}
        self._keepalive_running = True
        self._keepalive_thread = threading.Thread(
            target=self._keepalive_thread_run, daemon=True
        )
        self._keepalive_thread.start()

        logger.info(f"Connected to etcd3 at {self._host}:{self._port}")

    def __del__(self):
        """Clean up resources when the object is deleted."""
        self._keepalive_running = False
        if hasattr(self, "_keepalive_thread"):
            self._keepalive_thread.join(timeout=5)
        self.reset()
        if hasattr(self, "_client"):
            self._client.close()

    def _create_lease(self, ttl_seconds):
        """Create an etcd lease with the specified TTL.

        Args:
            ttl_seconds: Time-to-live in seconds

        Returns:
            The lease ID
        """
        lease = self._client.lease(ttl_seconds)
        return lease.id

    def add(
        self,
        name,
        value,
        delete_on_exit=True,
        keepalive_ttl=300,
        replace=False,
    ):
        """Add a key-value pair to etcd with optional TTL.

        Args:
            name: Key name
            value: Value to store
            delete_on_exit: Whether to delete the key when this object is destroyed
            keepalive_ttl: TTL in seconds for the key (default: 300)
            replace: Whether to replace an existing key

        Raises:
            NameEntryExistsError: If the key already exists and replace is False
        """
        name = name.rstrip("/")
        value = str(value)

        with self._lock:
            # Check if key exists when replace=False
            if not replace:
                existing_value, _ = self._client.get(name)
                if existing_value is not None:
                    raise NameEntryExistsError(
                        f"Key already exists: K={name} V={existing_value.decode()}"
                    )

            # Create lease for TTL if specified
            lease_id = None
            if keepalive_ttl is not None and keepalive_ttl > 0:
                lease_id = self._create_lease(keepalive_ttl)
                # Encode the string value to bytes
                self._client.put(name, value.encode("utf-8"), lease=lease_id)
            else:
                # Encode the string value to bytes
                self._client.put(name, value.encode("utf-8"))

            # Store entry information for keepalive management
            self._entries[name] = self._Entry(
                value=value,
                lease_id=lease_id,
                keepalive_ttl=keepalive_ttl,
                keeper=(
                    timeutil.FrequencyControl(frequency_seconds=keepalive_ttl / 3)
                    if keepalive_ttl
                    else None
                ),
            )

    def delete(self, name):
        """Delete a key from etcd.

        Args:
            name: Key to delete

        Raises:
            NameEntryNotFoundError: If the key doesn't exist
        """
        with self._lock:
            self._delete_locked(name)

    def _delete_locked(self, name):
        """Delete a key from etcd with lock already acquired.

        Args:
            name: Key to delete

        Raises:
            NameEntryNotFoundError: If the key doesn't exist
        """
        # First check if the key exists
        value, _ = self._client.get(name)
        if value is None:
            raise NameEntryNotFoundError(f"No such etcd entry to delete: {name}")

        # Clean up entry tracking
        if name in self._entries:
            del self._entries[name]

        # Delete from etcd
        self._client.delete(name)

    def clear_subtree(self, name_root):
        """Delete all keys with the given prefix.

        Args:
            name_root: Prefix to match keys against
        """
        with self._lock:
            count = 0
            name_root = name_root.rstrip("/")
            # Get all keys with the prefix
            for key_metadata_tuple in self._client.get_prefix(name_root):
                key = key_metadata_tuple[1].key.decode(
                    "utf-8"
                )  # Extract the key from metadata
                # Remove from our tracking
                if key in self._entries:
                    del self._entries[key]
                # Delete from etcd
                self._client.delete(key)
                count += 1

            logger.debug(f"Deleted {count} etcd entries under {name_root}")

    def get_subtree(self, name_root):
        """Get all values with keys having the given prefix.

        Args:
            name_root: Prefix to match keys against

        Returns:
            List of values
        """
        with self._lock:
            rs = []
            name_root = name_root.rstrip("/")
            for value_metadata_tuple in self._client.get_prefix(name_root):
                value = value_metadata_tuple[0].decode("utf-8")  # Extract the value
                rs.append(value)
            return sorted(rs)

    def find_subtree(self, name_root):
        """Find all keys with the given prefix.

        Args:
            name_root: Prefix to match keys against

        Returns:
            List of keys
        """
        with self._lock:
            rs = []
            for key_metadata_tuple in self._client.get_prefix(name_root):
                key = key_metadata_tuple[1].key.decode(
                    "utf-8"
                )  # Extract the key from metadata
                rs.append(key)
            return sorted(rs)

    def get(self, name):
        """Get the value for a key.

        Args:
            name: Key to retrieve

        Returns:
            The value as a string

        Raises:
            NameEntryNotFoundError: If the key doesn't exist
        """
        with self._lock:
            return self._get_locked(name)

    def _get_locked(self, name):
        """Get a value with lock already acquired.

        Args:
            name: Key to retrieve

        Returns:
            The value as a string

        Raises:
            NameEntryNotFoundError: If the key doesn't exist
        """
        value, _ = self._client.get(name)
        if value is None:
            raise NameEntryNotFoundError(f"No such etcd entry: {name}")
        return value.decode("utf-8")

    def reset(self):
        """Delete all keys added via this repository instance."""
        with self._lock:
            count = 0
            for name in list(self._entries):
                try:
                    self._delete_locked(name)
                    count += 1
                except NameEntryNotFoundError:
                    pass
            self._entries = {}
            logger.info(f"Reset {count} saved etcd entries")

    def _keepalive_thread_run(self):
        """Background thread to keep leases alive."""
        while self._keepalive_running:
            time.sleep(self.KEEPALIVE_POLL_FREQUENCY)
            with self._lock:
                for name, entry in list(self._entries.items()):
                    if (
                        entry.keeper is not None
                        and entry.keepalive_ttl is not None
                        and entry.lease_id is not None
                        and entry.keeper.check()
                    ):
                        try:
                            # Refresh the lease
                            self._client.refresh_lease(entry.lease_id)
                        except Exception as e:
                            logger.error(
                                f"Failed to refresh lease for key: K={name} V={entry.value}. Error: {e}"
                            )

    def watch_names(
        self,
        names: List,
        call_back: Callable,
        poll_frequency=15,
        wait_timeout=300,
    ):
        """Watch keys and call back when they are deleted.

        Args:
            names: Keys to watch
            call_back: Function to call when any key is deleted
            poll_frequency: How often to check in seconds
            wait_timeout: Maximum time to wait for keys to exist
        """
        if isinstance(names, str):
            names = [names]

        # Use etcd's native watch capability for more efficient watching
        for name in names:
            # First wait for the key to exist
            self.wait(name, timeout=wait_timeout, poll_frequency=poll_frequency)

            # Start watching for key deletion
            watch_id = self._client.add_watch_callback(
                name, lambda event: self._watch_callback(event, call_back)
            )

            # Store watch ID for cleanup
            if not hasattr(self, "_watch_ids"):
                self._watch_ids = []
            self._watch_ids.append(watch_id)

    def _watch_callback(self, event, callback):
        """Process watch events and call back on deletion.

        Args:
            event: The etcd watch response (WatchResponse object)
            callback: Function to call when a key is deleted
        """
        # Iterate through the events in the WatchResponse
        for ev in event.events:
            # Check if this is a delete event
            if isinstance(ev, etcd3.events.DeleteEvent):
                logger.debug(f"Key {ev.key.decode()} was deleted. Executing callback.")
                callback()

    def _testonly_drop_cached_entry(self, name):
        """Used by unittest only to simulate the case that the process crashes.

        Args:
            name: Key to drop from local cache
        """
        with self._lock:
            if name in self._entries:
                del self._entries[name]
                logger.debug(f"Testonly: dropped key: {name}")


def make_repository(type_="nfs", **kwargs):
    if type_ == "memory":
        return MemoryNameRecordRepository(**kwargs)
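A usage sketch of the TTL and watch features defined above (assumes a reachable etcd server at REAL_ETCD_ADDR; the key names are illustrative, not taken from the PR):

from realhf.base.name_resolve import Etcd3NameRecordRepository

repo = Etcd3NameRecordRepository()  # host/port default to REAL_ETCD_ADDR

# The background keepalive thread refreshes the lease roughly every
# keepalive_ttl / 3 seconds, so the key persists while this process is alive
# and expires shortly after it dies.
repo.add("demo/master", "10.0.0.1:7777", keepalive_ttl=300, replace=True)

def on_master_gone():
    print("master key was deleted")

# Waits for the key to exist, then registers an etcd watch that fires the
# callback when the key is deleted.
repo.watch_names(["demo/master"], on_master_gone)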
@@ -539,12 +889,16 @@ def make_repository(type_="nfs", **kwargs):
        return NfsNameRecordRepository(**kwargs)
    elif type_ == "redis":
        return RedisNameRecordRepository(**kwargs)
+   elif type_ == "etcd3":
+       return Etcd3NameRecordRepository(**kwargs)
    else:
        raise NotImplementedError(f"No such name resolver: {type_}")


# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs"
DEFAULT_REPOSITORY_TYPE = "nfs"
+if etcd3 is not None and cluster.spec.name in ["wa180"]:
+    DEFAULT_REPOSITORY_TYPE = "etcd3"
DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE)
add = DEFAULT_REPOSITORY.add
add_subentry = DEFAULT_REPOSITORY.add_subentry
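With the branch above, the default backend only switches to etcd3 on the cluster named "wa180"; elsewhere it can still be selected explicitly. A small sketch (assumes python-etcd3 is installed and an etcd server is reachable; the key names are made up for illustration):

from realhf.base import name_resolve

repo = name_resolve.make_repository("etcd3", host="localhost", port=2379)
repo.add("experiments/demo/master", "worker-0", replace=True)
print(repo.get("experiments/demo/master"))    # -> "worker-0"
print(repo.find_subtree("experiments/demo"))  # -> ["experiments/demo/master"]
repo.clear_subtree("experiments/demo")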
@@ -179,8 +179,10 @@ class ReaLModel(nn.Module):
    def save_to_hf(self, tokenizer, save_dir):
        return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)

-   def load_from_hf(self, load_dir):
-       return getattr(self, f"from_{self.hf_model_family}")(load_dir)
+   def load_from_hf(self, load_dir, init_critic_from_actor):
+       return getattr(self, f"from_{self.hf_model_family}")(
+           load_dir, init_critic_from_actor
+       )

    @property
    def pre_process(self):
@@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import transformers

-from realhf.base import constants, logging
+from realhf.base import cluster, constants, logging
from realhf.impl.model.utils.padding import pad_input, unpad_input

logger = logging.getLogger("Modeling Functional Utils")
@@ -166,7 +166,6 @@ def build_leave_one_indices(
    )


-@torch.compile
def gather_logprobs(
    logits: torch.Tensor,
    labels: torch.Tensor,
@@ -187,6 +186,10 @@ def gather_logprobs(
    return log_probs_labels


+if cluster.spec.name != "wa180":
+    gather_logprobs = torch.compile(gather_logprobs)
+
+
def gather_packed_shifted_log_probs(
    logits: torch.FloatTensor,
    cu_seqlens: torch.Tensor,
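Removing the decorator and rebinding the name conditionally keeps call sites unchanged while skipping compilation on the "wa180" cluster. An illustrative sketch of the same pattern (the function and the support flag are assumptions, not AReaL code):

import torch

def gather_example(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same shape of computation as gather_logprobs: log-softmax, then gather
    # the log-probability of each label token.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1)

COMPILE_SUPPORTED = False  # assumed flag; AReaL derives this from cluster.spec.name
if COMPILE_SUPPORTED:
    gather_example = torch.compile(gather_example)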
@@ -624,6 +624,7 @@ class RayController:
            REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
            REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
            REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
+           REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
        )
        runtime_env = {
            "env_vars": env_vars,
@@ -383,7 +383,7 @@ class NameResolvingReplyServer:
                experiment_name, trial_name, PUBSUB_BARRIER_NAME
            ),
            value=socket.gethostbyname(socket.gethostname()),
-           keepalive_ttl=60,
+           keepalive_ttl=1200,
        )

    def accept(self, server_send_addr: str, server_recv_addr: str):
@@ -20,7 +20,7 @@ logger = logging.getLogger("worker")

_MAX_SOCKET_CONCURRENCY = 1000
WORKER_WAIT_FOR_CONTROLLER_SECONDS = 3600
-WORKER_JOB_STATUS_LINGER_SECONDS = 60
+WORKER_JOB_STATUS_LINGER_SECONDS = 1800


class WorkerException(Exception):
@@ -127,7 +127,7 @@ class WorkerServer:
        if experiment_name is not None and trial_name is not None:
            key = names.worker(experiment_name, trial_name, worker_name)
            address = f"{host_ip}:{self.__task_queue.port}"
-           name_resolve.add(key, address, keepalive_ttl=10, delete_on_exit=True)
+           name_resolve.add(key, address, keepalive_ttl=1200, delete_on_exit=True)
            logger.debug(
                "Added name_resolve entry %s for worker server at %s",
                key,
@@ -54,3 +54,5 @@ cookiecutter>2.1.1
asyncio
aiohttp
httpx>=0.28.1
+etcd3
+protobuf<3.21
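A quick way to confirm the optional dependency resolves in a given environment (a sketch, not part of the diff); the protobuf<3.21 pin is presumably needed because python-etcd3's generated gRPC stubs predate the protobuf 4.x API:

try:
    import etcd3  # requires the two new requirements above
    print("etcd3 backend available")
except Exception as exc:
    # Mirrors the guarded import in realhf.base.name_resolve: without etcd3,
    # the default repository type stays "nfs".
    print("etcd3 backend unavailable:", exc)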
@@ -0,0 +1,151 @@
import os
import time

import etcd3
import pytest

from realhf.base.name_resolve import (
    Etcd3NameRecordRepository,
    NameEntryExistsError,
    NameEntryNotFoundError,
)

host, port = os.getenv("REAL_ETCD_ADDR", "localhost:2379").split(":")
port = int(port)


@pytest.fixture
def etcd_client():
    client = etcd3.client(host=host, port=port)
    yield client
    # Clean up etcd after each test
    client.delete_prefix("test_")  # Delete all keys


# Fixture to provide an instance of Etcd3NameRecordRepository
@pytest.fixture
def etcd_repo():
    repo = Etcd3NameRecordRepository(host=host, port=port)
    yield repo
    repo.reset()  # Clean up repository after each test


def test_add(etcd_repo):
    # Test adding a new key-value pair
    etcd_repo.add("test_key", "test_value")
    value, _ = etcd_repo._client.get("test_key")
    assert value.decode("utf-8") == "test_value"

    # Test adding a key that already exists without replace
    with pytest.raises(NameEntryExistsError):
        etcd_repo.add("test_key", "new_value", replace=False)

    # Test adding a key that already exists with replace
    etcd_repo.add("test_key", "new_value", replace=True)
    value, _ = etcd_repo._client.get("test_key")
    assert value.decode("utf-8") == "new_value"


def test_delete(etcd_repo):
    # Test deleting an existing key
    etcd_repo.add("test_key", "test_value")
    etcd_repo.delete("test_key")
    value, _ = etcd_repo._client.get("test_key")
    assert value is None

    # Test deleting a non-existent key
    with pytest.raises(NameEntryNotFoundError):
        etcd_repo.delete("non_existent_key")


def test_clear_subtree(etcd_repo):
    # Test clearing a subtree
    etcd_repo.add("test_key/sub1", "value1")
    etcd_repo.add("test_key/sub2", "value2")
    etcd_repo.clear_subtree("test_key")
    value1, _ = etcd_repo._client.get("test_key/sub1")
    value2, _ = etcd_repo._client.get("test_key/sub2")
    assert value1 is None
    assert value2 is None


def test_get(etcd_repo):
    # Test getting an existing key
    etcd_repo.add("test_key", "test_value")
    assert etcd_repo.get("test_key") == "test_value"

    # Test getting a non-existent key
    with pytest.raises(NameEntryNotFoundError):
        etcd_repo.get("non_existent_key")


def test_get_subtree(etcd_repo):
    # Test getting values from a subtree
    etcd_repo.add("test_key/sub1", "value1")
    etcd_repo.add("test_key/sub2", "value2")
    assert etcd_repo.get_subtree("test_key") == ["value1", "value2"]


def test_find_subtree(etcd_repo):
    # Test finding keys in a subtree
    etcd_repo.add("test_key/sub1", "value1")
    etcd_repo.add("test_key/sub2", "value2")
    assert etcd_repo.find_subtree("test_key") == ["test_key/sub1", "test_key/sub2"]


def test_reset(etcd_repo):
    # Test resetting the repository
    etcd_repo.add("test_key1", "value1", delete_on_exit=True)
    etcd_repo.add("test_key2", "value2", delete_on_exit=True)
    etcd_repo.reset()
    value1, _ = etcd_repo._client.get("test_key1")
    value2, _ = etcd_repo._client.get("test_key2")
    assert value1 is None
    assert value2 is None


def test_watch_names(etcd_repo):
    # Test watching keys
    callback_called = False

    def callback():
        nonlocal callback_called
        callback_called = True

    etcd_repo.add("test_key", "test_value")
    etcd_repo.watch_names(["test_key"], callback)

    # Delete the key to trigger the callback
    etcd_repo.delete("test_key")
    time.sleep(1)  # Give the watcher time to trigger
    assert callback_called


def test_keepalive_thread(etcd_repo):
    # Test the keepalive thread
    etcd_repo.add("test_key", "test_value", keepalive_ttl=2)
    time.sleep(1)  # Wait for the keepalive thread to refresh the lease
    # Ensure the key still exists
    value, _ = etcd_repo._client.get("test_key")
    assert value.decode("utf-8") == "test_value"
    time.sleep(2)  # Wait for the lease to expire
    with pytest.raises(NameEntryNotFoundError):
        etcd_repo.get("test_key")


def test_context_manager(etcd_repo):
    # Test the context manager
    with etcd_repo as repo:
        repo.add("test_key", "test_value", delete_on_exit=True)
        assert repo.get("test_key") == "test_value"
    # Ensure the key is deleted after exiting the context
    value, _ = etcd_repo._client.get("test_key")
    assert value is None


def test_del(etcd_repo, etcd_client):
    # Test the destructor
    etcd_repo.add("test_key", "test_value", delete_on_exit=True)
    etcd_repo.__del__()
    value, _ = etcd_client.get("test_key")
    assert value is None
@@ -95,14 +95,13 @@ def run_test_exp(
    testcase.start()

    # Run the master worker.
-   for _ in range(100):
-       for _ in range(100):
-           if mas.status == WorkerServerStatus.PAUSED:
-               break
-           if not initd:
-               logger.info("Running master worker lazy initialization...")
-           mas._poll()
-           if not initd:
-               logger.info("Running master worker lazy initialization... Done.")
-               initd = True
-           testcase.wait(timeout=0.1)
+   for _ in range(int(1e4)):
+       if mas.status == WorkerServerStatus.PAUSED:
+           break
+       if not initd:
+           logger.info("Running master worker lazy initialization...")
+       mas._poll()
+       if not initd:
+           logger.info("Running master worker lazy initialization... Done.")
+           initd = True
+       testcase.wait(timeout=0.1)