PullRequest: 58 Support ETCD3 name resolving repo

Merge branch fw/etcd of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/58

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* .
This commit is contained in:
博惟 2025-03-25 16:05:04 +08:00
parent 6ccbb01ca8
commit 9f77f96580
11 changed files with 535 additions and 22 deletions

View File

@ -163,6 +163,7 @@ def main_start(args, recover_count: int = 0):
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
)
for k, v in BASE_ENVIRONS.items():
os.environ[k] = v

View File

@ -62,7 +62,7 @@ def reveal_pg_identity(expr_name, trial_name, worker_index):
master_group_name = names.distributed_peer(
expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME
)
name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=30)
name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=300)
def isolate_cuda_device(

View File

@ -4,18 +4,21 @@
# Implements a simple name resolving service, which can be considered as a distributed key-value dict.
import dataclasses
import getpass
import os
import queue
import random
import shutil
import socket
import threading
import time
import uuid
from typing import Callable, List, Optional
from realhf.base import logging, security, timeutil
try:
import etcd3
except Exception:
etcd3 = None
from realhf.base import cluster, logging, security, timeutil
from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger("name-resolve")
@ -532,6 +535,353 @@ class RedisNameRecordRepository(NameRecordRepository):
print("Testonly: dropped key:", name)
class Etcd3NameRecordRepository(NameRecordRepository):
"""Implements a name record repository using etcd3 as the backend storage.
This implementation provides distributed key-value storage with support for
TTL-based expiration, atomic operations, and key watching functionality.
"""
# Default configuration
host, port = os.getenv("REAL_ETCD_ADDR", "localhost:2379").split(":")
ETCD_HOST = host
ETCD_PORT = int(port)
ETCD_USER = None
ETCD_PASSWORD = None
KEEPALIVE_POLL_FREQUENCY = 1
@dataclasses.dataclass
class _Entry:
value: str
lease_id: Optional[int] = None
keepalive_ttl: Optional[int] = None
keeper: Optional[timeutil.FrequencyControl] = None
def __init__(self, host=None, port=None, user=None, password=None, **kwargs):
"""Initialize the etcd3 name record repository.
Args:
host: etcd server host (defaults to ETCD_HOST)
port: etcd server port (defaults to ETCD_PORT)
user: etcd username for authentication (defaults to ETCD_USER)
password: etcd password for authentication (defaults to ETCD_PASSWORD)
**kwargs: Additional configuration parameters
"""
super().__init__()
self._lock = threading.Lock()
# Set connection parameters
self._host = host or self.ETCD_HOST
self._port = port or self.ETCD_PORT
self._user = user or self.ETCD_USER
self._password = password or self.ETCD_PASSWORD
# Connect to etcd
self._client = etcd3.client(
host=self._host, port=self._port, user=self._user, password=self._password
)
# Keep track of entries for cleanup and keepalive
self._entries = {}
self._keepalive_running = True
self._keepalive_thread = threading.Thread(
target=self._keepalive_thread_run, daemon=True
)
self._keepalive_thread.start()
logger.info(f"Connected to etcd3 at {self._host}:{self._port}")
def __del__(self):
"""Clean up resources when the object is deleted."""
self._keepalive_running = False
if hasattr(self, "_keepalive_thread"):
self._keepalive_thread.join(timeout=5)
self.reset()
if hasattr(self, "_client"):
self._client.close()
def _create_lease(self, ttl_seconds):
"""Create an etcd lease with the specified TTL.
Args:
ttl_seconds: Time-to-live in seconds
Returns:
The lease ID
"""
lease = self._client.lease(ttl_seconds)
return lease.id
def add(
self,
name,
value,
delete_on_exit=True,
keepalive_ttl=300,
replace=False,
):
"""Add a key-value pair to etcd with optional TTL.
Args:
name: Key name
value: Value to store
delete_on_exit: Whether to delete the key when this object is destroyed
keepalive_ttl: TTL in seconds for the key (default: 10)
replace: Whether to replace an existing key
Raises:
NameEntryExistsError: If the key already exists and replace is False
"""
name = name.rstrip("/")
value = str(value)
with self._lock:
# Check if key exists when replace=False
if not replace:
existing_value, _ = self._client.get(name)
if existing_value is not None:
raise NameEntryExistsError(
f"Key already exists: K={name} V={existing_value.decode()}"
)
# Create lease for TTL if specified
lease_id = None
if keepalive_ttl is not None and keepalive_ttl > 0:
lease_id = self._create_lease(keepalive_ttl)
# Encode the string value to bytes
self._client.put(name, value.encode("utf-8"), lease=lease_id)
else:
# Encode the string value to bytes
self._client.put(name, value.encode("utf-8"))
# Store entry information for keepalive management
self._entries[name] = self._Entry(
value=value,
lease_id=lease_id,
keepalive_ttl=keepalive_ttl,
keeper=(
timeutil.FrequencyControl(frequency_seconds=keepalive_ttl / 3)
if keepalive_ttl
else None
),
)
def delete(self, name):
"""Delete a key from etcd.
Args:
name: Key to delete
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
with self._lock:
self._delete_locked(name)
def _delete_locked(self, name):
"""Delete a key from etcd with lock already acquired.
Args:
name: Key to delete
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
# First check if the key exists
value, _ = self._client.get(name)
if value is None:
raise NameEntryNotFoundError(f"No such etcd entry to delete: {name}")
# Clean up entry tracking
if name in self._entries:
del self._entries[name]
# Delete from etcd
self._client.delete(name)
def clear_subtree(self, name_root):
"""Delete all keys with the given prefix.
Args:
name_root: Prefix to match keys against
"""
with self._lock:
count = 0
name_root = name_root.rstrip("/")
# Get all keys with the prefix
for key_metadata_tuple in self._client.get_prefix(name_root):
key = key_metadata_tuple[1].key.decode(
"utf-8"
) # Extract the key from metadata
# Remove from our tracking
if key in self._entries:
del self._entries[key]
# Delete from etcd
self._client.delete(key)
count += 1
logger.debug(f"Deleted {count} etcd entries under {name_root}")
def get_subtree(self, name_root):
"""Get all values with keys having the given prefix.
Args:
name_root: Prefix to match keys against
Returns:
List of values
"""
with self._lock:
rs = []
name_root = name_root.rstrip("/")
for value_metadata_tuple in self._client.get_prefix(name_root):
value = value_metadata_tuple[0].decode("utf-8") # Extract the value
rs.append(value)
return sorted(rs)
def find_subtree(self, name_root):
"""Find all keys with the given prefix.
Args:
name_root: Prefix to match keys against
Returns:
List of keys
"""
with self._lock:
rs = []
for key_metadata_tuple in self._client.get_prefix(name_root):
key = key_metadata_tuple[1].key.decode(
"utf-8"
) # Extract the key from metadata
rs.append(key)
return sorted(rs)
def get(self, name):
"""Get the value for a key.
Args:
name: Key to retrieve
Returns:
The value as a string
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
with self._lock:
return self._get_locked(name)
def _get_locked(self, name):
"""Get a value with lock already acquired.
Args:
name: Key to retrieve
Returns:
The value as a string
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
value, _ = self._client.get(name)
if value is None:
raise NameEntryNotFoundError(f"No such etcd entry: {name}")
return value.decode("utf-8")
def reset(self):
"""Delete all keys added via this repository instance."""
with self._lock:
count = 0
for name in list(self._entries):
try:
self._delete_locked(name)
count += 1
except NameEntryNotFoundError:
pass
self._entries = {}
logger.info(f"Reset {count} saved etcd entries")
def _keepalive_thread_run(self):
"""Background thread to keep leases alive."""
while self._keepalive_running:
time.sleep(self.KEEPALIVE_POLL_FREQUENCY)
with self._lock:
for name, entry in list(self._entries.items()):
if (
entry.keeper is not None
and entry.keepalive_ttl is not None
and entry.lease_id is not None
and entry.keeper.check()
):
try:
# Refresh the lease
self._client.refresh_lease(entry.lease_id)
except Exception as e:
logger.error(
f"Failed to refresh lease for key: K={name} V={entry.value}. Error: {e}"
)
def watch_names(
self,
names: List,
call_back: Callable,
poll_frequency=15,
wait_timeout=300,
):
"""Watch keys and call back when they are deleted.
Args:
names: Keys to watch
call_back: Function to call when any key is deleted
poll_frequency: How often to check in seconds
wait_timeout: Maximum time to wait for keys to exist
"""
if isinstance(names, str):
names = [names]
# Use etcd's native watch capability for more efficient watching
for name in names:
# First wait for the key to exist
self.wait(name, timeout=wait_timeout, poll_frequency=poll_frequency)
# Start watching for key deletion
watch_id = self._client.add_watch_callback(
name, lambda event: self._watch_callback(event, call_back)
)
# Store watch ID for cleanup
if not hasattr(self, "_watch_ids"):
self._watch_ids = []
self._watch_ids.append(watch_id)
def _watch_callback(self, event, callback):
"""Process watch events and call back on deletion.
Args:
event: The etcd watch response (WatchResponse object)
callback: Function to call when a key is deleted
"""
# Iterate through the events in the WatchResponse
for ev in event.events:
# Check if this is a delete event
if isinstance(ev, etcd3.events.DeleteEvent):
logger.debug(f"Key {ev.key.decode()} was deleted. Executing callback.")
callback()
def _testonly_drop_cached_entry(self, name):
"""Used by unittest only to simulate the case that the process crashes.
Args:
name: Key to drop from local cache
"""
with self._lock:
if name in self._entries:
del self._entries[name]
logger.debug(f"Testonly: dropped key: {name}")
def make_repository(type_="nfs", **kwargs):
if type_ == "memory":
return MemoryNameRecordRepository(**kwargs)
@ -539,12 +889,16 @@ def make_repository(type_="nfs", **kwargs):
return NfsNameRecordRepository(**kwargs)
elif type_ == "redis":
return RedisNameRecordRepository(**kwargs)
elif type_ == "etcd3":
return Etcd3NameRecordRepository(**kwargs)
else:
raise NotImplementedError(f"No such name resolver: {type_}")
# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs"
DEFAULT_REPOSITORY_TYPE = "nfs"
if etcd3 is not None and cluster.spec.name in ["wa180"]:
DEFAULT_REPOSITORY_TYPE = "etcd3"
DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE)
add = DEFAULT_REPOSITORY.add
add_subentry = DEFAULT_REPOSITORY.add_subentry

View File

@ -179,8 +179,10 @@ class ReaLModel(nn.Module):
def save_to_hf(self, tokenizer, save_dir):
return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)
def load_from_hf(self, load_dir):
return getattr(self, f"from_{self.hf_model_family}")(load_dir)
def load_from_hf(self, load_dir, init_critic_from_actor):
return getattr(self, f"from_{self.hf_model_family}")(
load_dir, init_critic_from_actor
)
@property
def pre_process(self):

View File

@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import transformers
from realhf.base import constants, logging
from realhf.base import cluster, constants, logging
from realhf.impl.model.utils.padding import pad_input, unpad_input
logger = logging.getLogger("Modeling Functional Utils")
@ -166,7 +166,6 @@ def build_leave_one_indices(
)
@torch.compile
def gather_logprobs(
logits: torch.Tensor,
labels: torch.Tensor,
@ -187,6 +186,10 @@ def gather_logprobs(
return log_probs_labels
if cluster.spec.name != "wa180":
gather_logprobs = torch.compile(gather_logprobs)
def gather_packed_shifted_log_probs(
logits: torch.FloatTensor,
cu_seqlens: torch.Tensor,

View File

@ -624,6 +624,7 @@ class RayController:
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
)
runtime_env = {
"env_vars": env_vars,

View File

@ -383,7 +383,7 @@ class NameResolvingReplyServer:
experiment_name, trial_name, PUBSUB_BARRIER_NAME
),
value=socket.gethostbyname(socket.gethostname()),
keepalive_ttl=60,
keepalive_ttl=1200,
)
def accept(self, server_send_addr: str, server_recv_addr: str):

View File

@ -20,7 +20,7 @@ logger = logging.getLogger("worker")
_MAX_SOCKET_CONCURRENCY = 1000
WORKER_WAIT_FOR_CONTROLLER_SECONDS = 3600
WORKER_JOB_STATUS_LINGER_SECONDS = 60
WORKER_JOB_STATUS_LINGER_SECONDS = 1800
class WorkerException(Exception):
@ -127,7 +127,7 @@ class WorkerServer:
if experiment_name is not None and trial_name is not None:
key = names.worker(experiment_name, trial_name, worker_name)
address = f"{host_ip}:{self.__task_queue.port}"
name_resolve.add(key, address, keepalive_ttl=10, delete_on_exit=True)
name_resolve.add(key, address, keepalive_ttl=1200, delete_on_exit=True)
logger.debug(
"Added name_resolve entry %s for worker server at %s",
key,

View File

@ -54,3 +54,5 @@ cookiecutter>2.1.1
asyncio
aiohttp
httpx>=0.28.1
etcd3
protobuf<3.21

View File

@ -0,0 +1,151 @@
import os
import time
import etcd3
import pytest
from realhf.base.name_resolve import (
Etcd3NameRecordRepository,
NameEntryExistsError,
NameEntryNotFoundError,
)
host, port = os.getenv("REAL_ETCD_ADDR", "localhost:2379").split(":")
port = int(port)
@pytest.fixture
def etcd_client():
client = etcd3.client(host=host, port=port)
yield client
# Clean up etcd after each test
client.delete_prefix("test_") # Delete all keys
# Fixture to provide an instance of Etcd3NameRecordRepository
@pytest.fixture
def etcd_repo():
repo = Etcd3NameRecordRepository(host=host, port=port)
yield repo
repo.reset() # Clean up repository after each test
def test_add(etcd_repo):
# Test adding a new key-value pair
etcd_repo.add("test_key", "test_value")
value, _ = etcd_repo._client.get("test_key")
assert value.decode("utf-8") == "test_value"
# Test adding a key that already exists without replace
with pytest.raises(NameEntryExistsError):
etcd_repo.add("test_key", "new_value", replace=False)
# Test adding a key that already exists with replace
etcd_repo.add("test_key", "new_value", replace=True)
value, _ = etcd_repo._client.get("test_key")
assert value.decode("utf-8") == "new_value"
def test_delete(etcd_repo):
# Test deleting an existing key
etcd_repo.add("test_key", "test_value")
etcd_repo.delete("test_key")
value, _ = etcd_repo._client.get("test_key")
assert value is None
# Test deleting a non-existent key
with pytest.raises(NameEntryNotFoundError):
etcd_repo.delete("non_existent_key")
def test_clear_subtree(etcd_repo):
# Test clearing a subtree
etcd_repo.add("test_key/sub1", "value1")
etcd_repo.add("test_key/sub2", "value2")
etcd_repo.clear_subtree("test_key")
value1, _ = etcd_repo._client.get("test_key/sub1")
value2, _ = etcd_repo._client.get("test_key/sub2")
assert value1 is None
assert value2 is None
def test_get(etcd_repo):
# Test getting an existing key
etcd_repo.add("test_key", "test_value")
assert etcd_repo.get("test_key") == "test_value"
# Test getting a non-existent key
with pytest.raises(NameEntryNotFoundError):
etcd_repo.get("non_existent_key")
def test_get_subtree(etcd_repo):
# Test getting values from a subtree
etcd_repo.add("test_key/sub1", "value1")
etcd_repo.add("test_key/sub2", "value2")
assert etcd_repo.get_subtree("test_key") == ["value1", "value2"]
def test_find_subtree(etcd_repo):
# Test finding keys in a subtree
etcd_repo.add("test_key/sub1", "value1")
etcd_repo.add("test_key/sub2", "value2")
assert etcd_repo.find_subtree("test_key") == ["test_key/sub1", "test_key/sub2"]
def test_reset(etcd_repo):
# Test resetting the repository
etcd_repo.add("test_key1", "value1", delete_on_exit=True)
etcd_repo.add("test_key2", "value2", delete_on_exit=True)
etcd_repo.reset()
value1, _ = etcd_repo._client.get("test_key1")
value2, _ = etcd_repo._client.get("test_key2")
assert value1 is None
assert value2 is None
def test_watch_names(etcd_repo):
# Test watching keys
callback_called = False
def callback():
nonlocal callback_called
callback_called = True
etcd_repo.add("test_key", "test_value")
etcd_repo.watch_names(["test_key"], callback)
# Delete the key to trigger the callback
etcd_repo.delete("test_key")
time.sleep(1) # Give the watcher time to trigger
assert callback_called
def test_keepalive_thread(etcd_repo):
# Test the keepalive thread
etcd_repo.add("test_key", "test_value", keepalive_ttl=2)
time.sleep(1) # Wait for the keepalive thread to refresh the lease
# Ensure the key still exists
value, _ = etcd_repo._client.get("test_key")
assert value.decode("utf-8") == "test_value"
time.sleep(2) # Wait for the lease to expire
with pytest.raises(NameEntryNotFoundError):
etcd_repo.get("test_key")
def test_context_manager(etcd_repo):
# Test the context manager
with etcd_repo as repo:
repo.add("test_key", "test_value", delete_on_exit=True)
assert repo.get("test_key") == "test_value"
# Ensure the key is deleted after exiting the context
value, _ = etcd_repo._client.get("test_key")
assert value is None
def test_del(etcd_repo, etcd_client):
# Test the destructor
etcd_repo.add("test_key", "test_value", delete_on_exit=True)
etcd_repo.__del__()
value, _ = etcd_client.get("test_key")
assert value is None

View File

@ -95,14 +95,13 @@ def run_test_exp(
testcase.start()
# Run the master worker.
for _ in range(100):
for _ in range(100):
if mas.status == WorkerServerStatus.PAUSED:
break
if not initd:
logger.info("Running master worker lazy initialization...")
mas._poll()
if not initd:
logger.info("Running master worker lazy initialization... Done.")
initd = True
testcase.wait(timeout=0.1)
for _ in range(int(1e4)):
if mas.status == WorkerServerStatus.PAUSED:
break
if not initd:
logger.info("Running master worker lazy initialization...")
mas._poll()
if not initd:
logger.info("Running master worker lazy initialization... Done.")
initd = True
testcase.wait(timeout=0.1)