diff --git a/realhf/system/generation_server.py b/realhf/system/generation_server.py index 96c043d..9bf39dc 100644 --- a/realhf/system/generation_server.py +++ b/realhf/system/generation_server.py @@ -41,11 +41,7 @@ def execute_shell_command(command: str) -> subprocess.Popen: ) -def launch_server_cmd(command: str, port: int = 30000): - """ - Launch the server using the given command. - If no port is specified, a free port is reserved. - """ +def apply_sglang_path(): p = Path(os.path.dirname(__file__)) patch_path = str( p.parent.parent @@ -72,6 +68,15 @@ def launch_server_cmd(command: str, port: int = 30000): ) proc.wait() logger.info(f"Applied SGLang patch at {target_path}") + + +def launch_server_cmd(command: str, port: int = 30000): + """ + Launch the server using the given command. + If no port is specified, a free port is reserved. + """ + if not ray.is_initialized(): + apply_sglang_path() assert port is not None full_command = f"{command} --port {port}" process = execute_shell_command(full_command)