AReaL/realhf/base/network.py

74 lines
2.3 KiB
Python

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import fcntl
import os
import socket
import time
from contextlib import closing
from functools import wraps
from realhf.base import constants, logging, name_resolve, names
# Module-level logger (from realhf.base.logging); used to report selected ports.
logger = logging.getLogger(__name__)
def gethostname():
    """Return the host name reported by ``socket.gethostname()``."""
    host = socket.gethostname()
    return host
def gethostip():
    """Resolve the local host name to an IPv4 address string.

    Resolution goes through ``socket.gethostbyname``, which raises an
    ``OSError`` subclass (e.g. ``socket.gaierror``) if the name cannot be
    resolved.
    """
    local_host = socket.gethostname()
    ip_address = socket.gethostbyname(local_host)
    return ip_address
def find_free_port(
    low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port"
):
    """Find a free TCP port in ``[low, high]`` and register it as used.

    A candidate port is obtained by binding a TCP socket to port 0 and reading
    back the port the OS assigned. The candidate is accepted only if it lies in
    ``[low, high]``, is not in ``exclude_ports``, and is not already listed in
    the ``name_resolve`` subtree ``names.used_ports(experiment_name,
    trial_name, <host ip>)``. The accepted port is then added to that subtree
    so later lookups of the same subtree will exclude it.

    The check-and-register step runs while holding an exclusive ``flock`` on a
    per-host lock file under ``constants.PORT_LOCK_FILE_ROOT``. The registry is
    re-read inside that critical section so the check always sees ports that
    other lock holders registered before us.

    Args:
        low: Smallest acceptable port number (inclusive).
        high: Largest acceptable port number (inclusive).
        exclude_ports: Optional iterable of ports that must not be returned.
            The caller's object is copied, never modified.
        experiment_name: Experiment name used to build the registry key.
        trial_name: Trial name used to build the registry key.

    Returns:
        The selected port number (int).

    Raises:
        ValueError: If ``low > high`` (empty range).

    Note:
        This retries (sleeping 50 ms between attempts) until a suitable port
        is found. Ports assigned for port 0 typically come from the OS
        ephemeral range, so a ``[low, high]`` range that does not overlap it
        will keep retrying -- NOTE(review): consider a bounded retry count.
    """
    if low > high:
        raise ValueError(f"Empty port range: low={low} > high={high}")

    host_ip = gethostip()
    ports_name = names.used_ports(experiment_name, trial_name, host_ip)
    # Copy into a fresh set so callers may pass any iterable (list, tuple,
    # set) and so their object is never mutated.
    caller_excluded = set(exclude_ports) if exclude_ports is not None else set()

    # Opening the lock file fails if its directory does not exist yet.
    os.makedirs(constants.PORT_LOCK_FILE_ROOT, exist_ok=True)
    lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, host_ip)

    while True:
        with open(lockfile, "w") as fd:
            # This will block until the lock is acquired.
            fcntl.flock(fd, fcntl.LOCK_EX)
            try:
                # Read the registry under the lock: a snapshot taken before
                # locking could miss a port another process registered in the
                # meantime, letting two processes claim the same port.
                registered = set(map(int, name_resolve.get_subtree(ports_name)))
                with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                    s.bind(("", 0))
                    # NOTE(review): set after bind(), so it does not affect
                    # this bind; kept to preserve the original behavior.
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    port = s.getsockname()[1]
                if (
                    low <= port <= high
                    and port not in caller_excluded
                    and port not in registered
                ):
                    name_resolve.add_subentry(ports_name, str(port))
                    logger.info(f"Found free port {port}")
                    return port
            finally:
                fcntl.flock(fd, fcntl.LOCK_UN)
        time.sleep(0.05)
def find_multiple_free_ports(
    count, low=1, high=65536, experiment_name="port", trial_name="port"
):
    """Reserve ``count`` mutually distinct free ports in ``[low, high]``.

    Ports are obtained one at a time from ``find_free_port``. The set of ports
    chosen so far is passed as ``exclude_ports`` on each call, so later picks
    never repeat earlier ones.

    Args:
        count: Number of ports to reserve.
        low: Smallest acceptable port number (inclusive).
        high: Largest acceptable port number (inclusive).
        experiment_name: Forwarded to ``find_free_port``.
        trial_name: Forwarded to ``find_free_port``.

    Returns:
        A list built from the set of reserved ports. Its order is set
        iteration order, not the order in which the ports were found.
    """
    shared_kwargs = dict(
        low=low,
        high=high,
        experiment_name=experiment_name,
        trial_name=trial_name,
    )
    reserved = set()
    for _ in range(count):
        new_port = find_free_port(exclude_ports=reserved, **shared_kwargs)
        reserved.add(new_port)
    return list(reserved)