mirror of https://github.com/inclusionAI/AReaL
PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * . * fix * .
This commit is contained in:
parent
d8038b2669
commit
724628eaf0
|
@ -92,7 +92,7 @@ def main_grpo():
|
|||
future.result()
|
||||
|
||||
# synchronous rollout
|
||||
rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
|
||||
rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
|
||||
# or asynchronous rollout with filtering and off-policyness control
|
||||
# rollout_batch = rollout.prepare_batch(batch,
|
||||
# workflow=MyRolloutWorkflow(rollout_config.workflow),
|
||||
|
@ -697,7 +697,7 @@ reward = TrainController(Critic())
|
|||
rollout_controller = RolloutController(...)
|
||||
for _ in range(epochs):
|
||||
for _ in range(steps_per_epoch):
|
||||
data = rollout_controller.rollout(prompt)
|
||||
data = rollout_controller.rollout_batch(prompt)
|
||||
data['reward'] = reward.compute_values(data)
|
||||
...
|
||||
```
|
||||
|
|
|
@ -199,6 +199,9 @@ class SGLangConfig:
|
|||
https://github.com/sgl-project/sglang for detailed documentation.
|
||||
"""
|
||||
|
||||
model_path: str = ""
|
||||
random_seed: int = 1
|
||||
skip_tokenizer_init: bool = False
|
||||
disable_cuda_graph: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
disable_cuda_graph_padding: bool = False
|
||||
|
@ -234,10 +237,8 @@ class SGLangConfig:
|
|||
schedule_policy: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
|
||||
dtype: str = "float16"
|
||||
kv_cache_dtype: str = "auto"
|
||||
|
||||
# logging
|
||||
log_level: str = "warning"
|
||||
log_level_http: Optional[str] = "warning"
|
||||
|
@ -253,55 +254,60 @@ class SGLangConfig:
|
|||
@staticmethod
|
||||
def build_cmd(
|
||||
sglang_config: "SGLangConfig",
|
||||
model_path,
|
||||
tp_size,
|
||||
base_gpu_id,
|
||||
host,
|
||||
port,
|
||||
dist_init_addr: Optional[str] = None,
|
||||
served_model_name: Optional[str] = None,
|
||||
skip_tokenizer_init: bool = True,
|
||||
sglang_version: Optional[str] = None,
|
||||
):
|
||||
from realhf.base import network, pkg_version, seeding
|
||||
from realhf.base import pkg_version
|
||||
from realhf.experiments.common.utils import asdict as conf_as_dict
|
||||
|
||||
args: Dict = conf_as_dict(sglang_config)
|
||||
args["random_seed"] = seeding.get_seed()
|
||||
|
||||
if served_model_name is None:
|
||||
served_model_name = model_path
|
||||
host_ip = network.gethostip()
|
||||
host = "localhost" if not sglang_config.enable_metrics else host_ip
|
||||
args = dict(
|
||||
host=host,
|
||||
model_path=model_path,
|
||||
port=port,
|
||||
# Model and tokenizer
|
||||
tokenizer_path=model_path,
|
||||
tokenizer_path=sglang_config.model_path,
|
||||
tokenizer_mode="auto",
|
||||
load_format="auto",
|
||||
trust_remote_code=True,
|
||||
device="cuda",
|
||||
served_model_name=served_model_name,
|
||||
is_embedding=False,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
# Other runtime options
|
||||
tp_size=tp_size,
|
||||
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
|
||||
base_gpu_id=base_gpu_id,
|
||||
nnodes=1,
|
||||
node_rank=0,
|
||||
# initialization addresses and ports
|
||||
dist_init_addr=dist_init_addr,
|
||||
**args,
|
||||
)
|
||||
|
||||
if pkg_version.is_version_less("sglang", "0.4.4"):
|
||||
if sglang_version:
|
||||
version_less_than_0_4_4 = (
|
||||
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
|
||||
)
|
||||
version_less_than_0_4_3 = (
|
||||
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
|
||||
)
|
||||
elif pkg_version.is_available("sglang"):
|
||||
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
|
||||
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
|
||||
else:
|
||||
raise ValueError(
|
||||
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
|
||||
)
|
||||
if version_less_than_0_4_4:
|
||||
args.pop("log_requests_level")
|
||||
if pkg_version.is_version_less("sglang", "0.4.3"):
|
||||
if version_less_than_0_4_3:
|
||||
args.pop("enable_nccl_nvls")
|
||||
args.pop("triton_attention_num_kv_splits")
|
||||
args.pop("cuda_graph_bs")
|
||||
args.pop("enable_memory_saver")
|
||||
args.pop("allow_auto_truncate")
|
||||
args.pop("file_storage_path")
|
||||
|
||||
flags = []
|
||||
for k, v in args.items():
|
||||
if v is None or v is False or v == "":
|
||||
|
@ -320,8 +326,8 @@ class SGLangConfig:
|
|||
|
||||
@dataclass
|
||||
class InferenceEngineConfig:
|
||||
experiment_name: str
|
||||
trial_name: str
|
||||
experiment_name: str = MISSING
|
||||
trial_name: str = MISSING
|
||||
max_concurrent_rollouts: None | int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
|
@ -345,27 +351,20 @@ class InferenceEngineConfig:
|
|||
},
|
||||
)
|
||||
# Used by remote inference engines.
|
||||
server_addrs: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "List of server addresses for inference."},
|
||||
)
|
||||
enable_rollout_tracing: bool = field(default=False)
|
||||
schedule_policy: str = field(
|
||||
default="round_robin",
|
||||
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
|
||||
)
|
||||
setup_timeout: float = field(default=90.0)
|
||||
request_timeout: float = field(
|
||||
default=30.0, metadata={"help": "Timeout for HTTP requests."}
|
||||
default=3600, metadata={"help": "Timeout for HTTP requests."}
|
||||
)
|
||||
request_retries: int = field(
|
||||
default=3, metadata={"help": "Number of retries for failed requests."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SGLangEngineConfig:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Timer:
|
||||
experiment_name: str = MISSING
|
||||
|
@ -595,42 +594,53 @@ class BaseExperimentConfig:
|
|||
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
|
||||
stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
|
||||
|
||||
server_only: bool = False
|
||||
sglang: SGLangConfig = field(default_factory=SGLangConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTConfig(BaseExperimentConfig):
|
||||
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
|
||||
|
||||
|
||||
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
|
||||
def parse_cli_args(argv: List[str]):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config", help="The path of the main configuration file", required=True
|
||||
)
|
||||
args, overrides = parser.parse_known_args(argv)
|
||||
|
||||
# Initialize hydra config
|
||||
config_file = Path(args.config).absolute()
|
||||
assert config_file.exists()
|
||||
# hydra only recognize relative paths
|
||||
relpath = Path(
|
||||
os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
|
||||
)
|
||||
relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
|
||||
hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
|
||||
cfg = hydra_compose(
|
||||
config_name=str(relpath.name).rstrip(".yaml"),
|
||||
config_name=str(relpath.name).split(".yaml")[0],
|
||||
overrides=overrides,
|
||||
)
|
||||
return cfg, config_file
|
||||
|
||||
|
||||
def to_structured_cfg(cfg, config_cls):
|
||||
# Merge with the default configuration.
|
||||
# The yaml and commandline can omit some default values defined in python dataclasses.
|
||||
default_cfg = OmegaConf.structured(config_cls)
|
||||
cfg = OmegaConf.merge(default_cfg, cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def load_expr_config(argv: List[str], config_cls):
|
||||
cfg, config_file = parse_cli_args(argv)
|
||||
cfg = to_structured_cfg(cfg, config_cls=config_cls)
|
||||
cfg = OmegaConf.to_object(cfg)
|
||||
assert isinstance(cfg, BaseExperimentConfig)
|
||||
|
||||
# Setup environment
|
||||
from realhf.base import constants, name_resolve
|
||||
from realhf.base import constants, name_resolve, names
|
||||
|
||||
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
|
||||
name_resolve.reconfigure(cfg.cluster.name_resolve)
|
||||
name_resolve.clear_subtree(
|
||||
names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
|
||||
)
|
||||
return cfg, str(config_file)
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
|||
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.io_struct import (
|
||||
FinetuneSpec,
|
||||
|
@ -77,9 +78,9 @@ class TrainEngine(abc.ABC):
|
|||
|
||||
def train_batch(
|
||||
self,
|
||||
input_: Dict,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> Dict[str, float]:
|
||||
"""Update the model with a batch of data and a loss function."""
|
||||
raise NotImplementedError()
|
||||
|
@ -87,9 +88,9 @@ class TrainEngine(abc.ABC):
|
|||
@torch.no_grad()
|
||||
def eval_batch(
|
||||
self,
|
||||
input_: Dict,
|
||||
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[Dict], float],
|
||||
input_: TensorDict,
|
||||
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
|
||||
loss_weight_fn: Callable[[TensorDict], float],
|
||||
) -> torch.Tensor | None:
|
||||
"""Evaluate the model using the forward pass and loss function."""
|
||||
raise NotImplementedError()
|
||||
|
@ -97,9 +98,9 @@ class TrainEngine(abc.ABC):
|
|||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_: Dict,
|
||||
input_: TensorDict,
|
||||
output_seqlens: List[List[int]] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
|
||||
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
|
||||
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
|
||||
) -> Any | None:
|
||||
"""Run the forward pass or inference on the model. Note that it is gradient-free."""
|
||||
|
@ -127,12 +128,33 @@ class InferenceEngine(abc.ABC):
|
|||
"""Asynchronously submit a request to the inference engine. Exits immediately."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def wait(self, count: int, timeout: float) -> TensorDict:
|
||||
def wait(
|
||||
self,
|
||||
count: int,
|
||||
timeout: float | None = None,
|
||||
should_accept: Callable | None = None,
|
||||
) -> TensorDict:
|
||||
"""Wait for a specified number of requests to complete, with a timeout."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def rollout(
|
||||
def rollout_batch(
|
||||
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
|
||||
) -> TensorDict:
|
||||
"""Submit a batch of requests to the inference engine and wait for the results."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def prepare_batch(
|
||||
self,
|
||||
dataloader: StatefulDataLoader,
|
||||
workflow: "RolloutWorkflow",
|
||||
):
|
||||
"""Asynchronously submit and wait until a full batch is ready."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def pause(self):
|
||||
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def resume(self):
|
||||
"""Resume request submission for async rollout."""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -16,7 +16,6 @@ from arealite.api.cli_args import GenerationHyperparameters
|
|||
@dataclass
|
||||
class LLMRequest:
|
||||
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text: Optional[str] = None
|
||||
input_ids: List[int] = field(default_factory=list)
|
||||
gconfig: GenerationHyperparameters = field(
|
||||
default_factory=GenerationHyperparameters
|
||||
|
@ -28,7 +27,6 @@ class LLMRequest:
|
|||
@dataclass
|
||||
class LLMResponse:
|
||||
# outputs
|
||||
completions: str
|
||||
input_tokens: List[int] = field(default_factory=list)
|
||||
output_tokens: List[int] = field(default_factory=list)
|
||||
output_logprobs: List[float] = field(default_factory=list)
|
||||
|
|
|
@ -130,11 +130,6 @@ class FSDPEngine(TrainEngine):
|
|||
)
|
||||
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
|
||||
|
||||
if self.config.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
|
||||
# Simple auto wrap policy
|
||||
self.mixed_precision_policy = MixedPrecisionPolicy(
|
||||
param_dtype=dtype,
|
||||
|
@ -318,7 +313,9 @@ class FSDPEngine(TrainEngine):
|
|||
self.config.trial_name,
|
||||
meta.model_version,
|
||||
)
|
||||
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
|
||||
name_resolve.add(
|
||||
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown weight update type {meta.type}")
|
||||
|
||||
|
|
|
@ -1,23 +1,30 @@
|
|||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from queue import Empty, Full, Queue
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
import torch.distributed as dist
|
||||
from tensordict import TensorDict
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import InferenceEngineConfig
|
||||
from arealite.api.engine_api import InferenceEngine
|
||||
from arealite.api.io_struct import (
|
||||
FinetuneSpec,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
RolloutStat,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from arealite.utils.padding import concat_padded_tensors
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -30,7 +37,7 @@ if pkg_version.is_available("sglang"):
|
|||
else:
|
||||
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
|
||||
|
||||
ROLLOUT_POLL_WAIT_TIME = 0.4
|
||||
ROLLOUT_POLL_WAIT_TIME = 0.1
|
||||
RID_CACHE_SIZE = 128
|
||||
|
||||
|
||||
|
@ -46,22 +53,51 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
# Maintain the addresses for the recent 128 requests
|
||||
self.rid_queue = []
|
||||
|
||||
self.addresses = config.server_addrs
|
||||
self.server_idx = 0
|
||||
self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
|
||||
if not self.addresses:
|
||||
raise RuntimeError("No configured SGLang servers.")
|
||||
logger.info("Waiting for server ready...")
|
||||
for addr in self.addresses:
|
||||
self._wait_for_server(addr)
|
||||
logger.info("Servers are all ready!")
|
||||
|
||||
qsize = config.queue_size or config.max_concurrent_rollouts * 10
|
||||
self.server_idx = random.randint(0, len(self.addresses) - 1)
|
||||
|
||||
qsize = config.queue_size or config.max_concurrent_rollouts * 16
|
||||
self.input_queue = Queue(maxsize=qsize)
|
||||
self.output_queue = Queue(maxsize=qsize)
|
||||
self.result_cache = []
|
||||
|
||||
self.exiting = threading.Event()
|
||||
self.paused = threading.Event()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
self.rollout_stat = RolloutStat()
|
||||
|
||||
self._version = 0
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
|
||||
def _wait_for_server(self, address):
|
||||
base_url = f"http://{address}"
|
||||
tik = time.time()
|
||||
while time.time() - tik < self.config.setup_timeout:
|
||||
if self.check_health(base_url):
|
||||
return
|
||||
time.sleep(1)
|
||||
raise RuntimeError("server launch failed")
|
||||
|
||||
def check_health(self, base_url):
|
||||
# Check server endpoint
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/metrics",
|
||||
timeout=30,
|
||||
)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.RequestException as e:
|
||||
return False
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
|
||||
self.rollout_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.rollout_thread = threading.Thread(target=self._rollout_thread)
|
||||
self.rollout_thread.start()
|
||||
|
||||
|
@ -85,79 +121,45 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
traceback.print_exc()
|
||||
|
||||
async def _rollout_thread_async(self):
|
||||
data = None
|
||||
|
||||
rollout_tasks: Dict[str, asyncio.Task] = {}
|
||||
pending_data = []
|
||||
rollout_tasks = self.rollout_tasks
|
||||
rid = 0
|
||||
|
||||
try:
|
||||
while not self.exiting.is_set():
|
||||
# Load next data from controller
|
||||
if data is None:
|
||||
while True:
|
||||
try:
|
||||
data, workflow = self.input_queue.get_nowait()
|
||||
logger.info(f"Get data from puller: {data}")
|
||||
logger.debug(f"Get data from puller: {data}")
|
||||
pending_data.append(data)
|
||||
except Empty:
|
||||
logger.debug(f"No data from puller stream.")
|
||||
break
|
||||
|
||||
# Check capacity
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
cannot_rollout_reason = []
|
||||
capacity = max(1, self.config.max_concurrent_rollouts // world_size)
|
||||
can_rollout = len(rollout_tasks) < capacity
|
||||
if not can_rollout:
|
||||
cannot_rollout_reason.append(
|
||||
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
|
||||
)
|
||||
|
||||
# Staleness control
|
||||
version = self.get_version()
|
||||
ofp = self.config.max_head_offpolicyness
|
||||
with self.lock:
|
||||
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
|
||||
expected_version = sample_cnt // self.config.consumer_batch_size
|
||||
not_staled = expected_version <= ofp + version
|
||||
can_rollout &= not_staled
|
||||
if not not_staled:
|
||||
cannot_rollout_reason.append(
|
||||
f"Staled: expected version ({expected_version}) = "
|
||||
f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
|
||||
f"current latest version {version}, "
|
||||
f"offpolicyness {self.config.max_head_offpolicyness}."
|
||||
)
|
||||
|
||||
if not can_rollout:
|
||||
logger.debug(
|
||||
f"Cannot submit new rollouts. "
|
||||
+ "\n".join(cannot_rollout_reason)
|
||||
)
|
||||
|
||||
capacity = self.get_capacity()
|
||||
# Create new rollout task
|
||||
if can_rollout and data is not None:
|
||||
while capacity > 0 and pending_data and not self.paused.is_set():
|
||||
task = asyncio.create_task(
|
||||
workflow.arun_episode(self, data), name=str(rid)
|
||||
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
|
||||
)
|
||||
rollout_tasks[str(rid)] = task
|
||||
|
||||
with self.lock:
|
||||
rollout_tasks[str(rid)] = task
|
||||
self.rollout_stat.submitted += 1
|
||||
self.rollout_stat.running += 1
|
||||
logger.info(
|
||||
f"Submit rollout rid {rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
|
||||
if self.config.enable_rollout_tracing:
|
||||
logger.info(
|
||||
f"Submit rollout rid {rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
capacity -= 1
|
||||
rid += 1
|
||||
data = None
|
||||
|
||||
# Wait for rollout completion
|
||||
tasks = list(rollout_tasks.values())
|
||||
with self.lock:
|
||||
tasks = list(rollout_tasks.values())
|
||||
done = []
|
||||
if tasks:
|
||||
done, _ = await asyncio.wait(
|
||||
|
@ -165,16 +167,19 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
timeout=ROLLOUT_POLL_WAIT_TIME,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
if not done:
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Collect done results
|
||||
for task in done:
|
||||
traj = await task
|
||||
traj: TensorDict
|
||||
task_rid = task.get_name()
|
||||
rollout_tasks.pop(task_rid)
|
||||
self.rollout_stat.accepted += 1
|
||||
with self.lock:
|
||||
rollout_tasks.pop(task_rid)
|
||||
self.rollout_stat.accepted += 1
|
||||
|
||||
try:
|
||||
self.output_queue.put_nowait(traj)
|
||||
|
@ -185,21 +190,25 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
|
||||
with self.lock:
|
||||
self.rollout_stat.running -= 1
|
||||
logger.info(
|
||||
f"Finish rollout {task_rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
if self.config.enable_rollout_tracing:
|
||||
logger.info(
|
||||
f"Finish rollout {task_rid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Cancel remaining tasks
|
||||
for task in rollout_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
with self.lock:
|
||||
for task in rollout_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def choose_server(self) -> str:
|
||||
if self.config.schedule_policy == "round_robin":
|
||||
|
@ -236,8 +245,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=timeout,
|
||||
sock_connect=30,
|
||||
sock_read=timeout,
|
||||
sock_connect=timeout,
|
||||
)
|
||||
) as session:
|
||||
if method.upper() == "GET":
|
||||
|
@ -252,7 +260,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response
|
||||
return await response.json()
|
||||
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
|
@ -288,15 +296,11 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
|
||||
# NOTE: rid should NOT be passed in payload
|
||||
payload = {
|
||||
"text": req.text,
|
||||
"input_ids": req.input_ids.copy(),
|
||||
"sampling_params": sample_params,
|
||||
"return_logprob": True,
|
||||
"stream": False,
|
||||
}
|
||||
if req.text:
|
||||
payload["text"] = req.text
|
||||
else:
|
||||
payload["input_ids"] = req.input_ids
|
||||
|
||||
# Make request
|
||||
start_time = time.perf_counter()
|
||||
|
@ -324,7 +328,7 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
and len(accumulated_output_tokens) < gconfig.max_new_tokens
|
||||
):
|
||||
# loop until the generation is complete
|
||||
response = await self.arequest_with_retry(
|
||||
result = await self.arequest_with_retry(
|
||||
endpoint="/generate",
|
||||
payload=payload,
|
||||
method="POST",
|
||||
|
@ -332,10 +336,8 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
timeout=self.config.request_timeout,
|
||||
target_addr=server_addr,
|
||||
)
|
||||
result = await response.json()
|
||||
|
||||
# Parse response
|
||||
completions += result["text"]
|
||||
meta_info = result["meta_info"]
|
||||
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
|
||||
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
|
||||
|
@ -350,12 +352,11 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
finish_reason = meta_info["finish_reason"]
|
||||
stop_reason = finish_reason["type"]
|
||||
|
||||
payload["text"] += result["text"]
|
||||
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
|
||||
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
return LLMResponse(
|
||||
completions=completions,
|
||||
input_tokens=req.input_ids,
|
||||
output_tokens=accumulated_output_tokens,
|
||||
output_logprobs=accumulated_output_logprobs,
|
||||
|
@ -376,10 +377,10 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
update_name = names.update_weights_from_disk(
|
||||
self.config.experiment_name, self.config.trial_name, meta.model_version
|
||||
)
|
||||
save_timestamp = int(name_resolve.wait(update_name, timeout=120))
|
||||
load_timestamp = time.time_ns()
|
||||
save_timestamp = float(name_resolve.wait(update_name, timeout=120))
|
||||
load_timestamp = datetime.now().timestamp()
|
||||
logger.info(
|
||||
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
|
||||
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp):.2f}s"
|
||||
)
|
||||
try:
|
||||
jobs = [
|
||||
|
@ -393,14 +394,14 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
finally:
|
||||
loop.close()
|
||||
logger.info(
|
||||
f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
|
||||
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
|
||||
)
|
||||
self.set_version(meta.model_version)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
|
||||
|
||||
async def aupdate_weights_from_disk(self, addr, path: str):
|
||||
response = await self.arequest_with_retry(
|
||||
res = await self.arequest_with_retry(
|
||||
endpoint="/update_weights_from_disk",
|
||||
payload=dict(model_path=str(path), allow_interrupt=True),
|
||||
method="POST",
|
||||
|
@ -408,7 +409,6 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
timeout=self.config.request_timeout,
|
||||
target_addr=addr,
|
||||
)
|
||||
res = await response.json()
|
||||
assert res["success"]
|
||||
if "num_paused_requests" in res:
|
||||
logger.info(
|
||||
|
@ -416,15 +416,40 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
f"during updating weights for server {addr}"
|
||||
)
|
||||
|
||||
def get_capacity(self):
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
max_concurrent_rollouts = max(
|
||||
1, self.config.max_concurrent_rollouts // world_size
|
||||
)
|
||||
capacity = max_concurrent_rollouts - len(self.rollout_tasks)
|
||||
# Staleness control
|
||||
version = self.get_version()
|
||||
ofp = self.config.max_head_offpolicyness
|
||||
with self.lock:
|
||||
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
|
||||
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
|
||||
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
|
||||
return capacity
|
||||
|
||||
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
|
||||
try:
|
||||
self.input_queue.put_nowait((data, workflow))
|
||||
except Full:
|
||||
raise RuntimeError("Input queue full. Please increase queue_size.")
|
||||
|
||||
def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
|
||||
def wait(
|
||||
self,
|
||||
count: int,
|
||||
timeout: float | None = None,
|
||||
should_accept: Callable | None = None,
|
||||
) -> TensorDict:
|
||||
tik = time.perf_counter()
|
||||
accepted = len(self.result_cache)
|
||||
timeout = timeout or float(7 * 24 * 3600)
|
||||
while (
|
||||
accepted < count
|
||||
and not self.exiting.is_set()
|
||||
|
@ -432,14 +457,14 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
):
|
||||
try:
|
||||
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
|
||||
if should_accept(result):
|
||||
if should_accept is None or should_accept(result):
|
||||
self.result_cache.append(result)
|
||||
accepted += 1
|
||||
else:
|
||||
with self.lock:
|
||||
self.rollout_stat.accepted -= 1
|
||||
except Empty:
|
||||
time.sleep(ROLLOUT_POLL_WAIT_TIME)
|
||||
pass
|
||||
if self.exiting.is_set():
|
||||
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
|
||||
if accepted < count:
|
||||
|
@ -450,16 +475,39 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
self.result_cache[:count],
|
||||
self.result_cache[count:],
|
||||
)
|
||||
return TensorDict.cat(results, dim=0)
|
||||
return concat_padded_tensors(results)
|
||||
|
||||
def rollout(
|
||||
def rollout_batch(
|
||||
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
|
||||
) -> TensorDict:
|
||||
"""Submit a batch of requests to the inference engine and wait for the results."""
|
||||
for item in data:
|
||||
self.submit(item, workflow)
|
||||
return self.wait(
|
||||
count=len(data),
|
||||
timeout=self.config.request_timeout,
|
||||
should_accept=lambda x: True,
|
||||
)
|
||||
return self.wait(count=len(data))
|
||||
|
||||
def prepare_batch(
|
||||
self,
|
||||
data_generator: Iterator,
|
||||
dataloader: StatefulDataLoader,
|
||||
workflow: "RolloutWorkflow",
|
||||
):
|
||||
assert dataloader.batch_size is not None
|
||||
while True:
|
||||
if self.get_capacity() + dataloader.batch_size > 0:
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(dataloader)
|
||||
data = next(data_generator)
|
||||
for item in data:
|
||||
self.submit(item, workflow=workflow)
|
||||
try:
|
||||
return self.wait(dataloader.batch_size, timeout=1)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
def pause(self):
|
||||
self.paused.set()
|
||||
|
||||
def resume(self):
|
||||
self.paused.clear()
|
||||
|
|
|
@ -5,7 +5,6 @@ import time
|
|||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
||||
|
@ -15,62 +14,43 @@ from arealite.api.cli_args import (
|
|||
SGLangConfig,
|
||||
)
|
||||
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
|
||||
from arealite.utils import network
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import network
|
||||
|
||||
EXPR_NAME = "test_sglang_engine"
|
||||
TRIAL_NAME = "trial_0"
|
||||
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
MODEL_PATH = "Qwen/Qwen2-0.5B"
|
||||
PORT = 13887
|
||||
DIST_PORT = 15887
|
||||
PORT, DIST_PORT = network.find_free_ports(2)
|
||||
HOST = network.gethostip()
|
||||
|
||||
|
||||
def check_server_health(base_url):
|
||||
# Check server endpoint
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/metrics",
|
||||
timeout=30,
|
||||
)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sglang_server():
|
||||
from realhf.base import seeding
|
||||
|
||||
seeding.set_random_seed(1, EXPR_NAME)
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
sglang_config=SGLangConfig(mem_fraction_static=0.3),
|
||||
model_path=MODEL_PATH,
|
||||
sglang_config=SGLangConfig(
|
||||
skip_tokenizer_init=True,
|
||||
model_path=MODEL_PATH,
|
||||
mem_fraction_static=0.3,
|
||||
),
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
tp_size=1,
|
||||
base_gpu_id=0,
|
||||
dist_init_addr=f"{HOST}:{DIST_PORT}",
|
||||
served_model_name=MODEL_PATH,
|
||||
skip_tokenizer_init=False,
|
||||
)
|
||||
# Launch process
|
||||
full_command = f"{cmd} --port {PORT}"
|
||||
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
|
||||
cmd = cmd.replace("\\\n", " ").replace("\\", " ")
|
||||
process = subprocess.Popen(
|
||||
full_command.split(),
|
||||
cmd.split(),
|
||||
text=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stdout,
|
||||
)
|
||||
base_url = f"http://{HOST}:{PORT}"
|
||||
tik = time.time()
|
||||
while time.time() - tik < 90:
|
||||
if check_server_health(base_url):
|
||||
break
|
||||
time.sleep(1)
|
||||
if time.time() - tik > 90:
|
||||
raise RuntimeError("server launch failed")
|
||||
yield
|
||||
process.terminate()
|
||||
|
||||
|
@ -80,11 +60,12 @@ async def test_remote_sglang_generate(sglang_server):
|
|||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
|
||||
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
tokenizer = load_hf_tokenizer(MODEL_PATH)
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
engine = RemoteSGLangEngine(config)
|
||||
req = LLMRequest(
|
||||
rid=str(uuid.uuid4()),
|
||||
text="hello! how are you today",
|
||||
input_ids=tokenizer.encode("hello! how are you today"),
|
||||
gconfig=GenerationHyperparameters(max_new_tokens=16),
|
||||
)
|
||||
resp = await engine.agenerate(req)
|
||||
|
@ -95,7 +76,6 @@ async def test_remote_sglang_generate(sglang_server):
|
|||
== len(resp.output_tokens)
|
||||
== len(resp.output_versions)
|
||||
)
|
||||
assert isinstance(resp.completions, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [1, 2, 4])
|
||||
|
@ -109,7 +89,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
|
|||
max_concurrent_rollouts=2,
|
||||
consumer_batch_size=2,
|
||||
)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
engine = RemoteSGLangEngine(config)
|
||||
engine.initialize(None, None)
|
||||
|
||||
|
@ -122,12 +102,13 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
|
|||
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
|
||||
gconfig=gconfig,
|
||||
tokenizer=tokenizer,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
}
|
||||
result = engine.rollout([data] * 2, workflow=workflow)
|
||||
result = engine.rollout_batch([data] * 2, workflow=workflow)
|
||||
assert isinstance(result, TensorDict)
|
||||
bs = result.batch_size
|
||||
assert bs == torch.Size([2 * n_samples])
|
||||
|
@ -147,7 +128,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
consumer_batch_size=bs,
|
||||
max_head_offpolicyness=ofp,
|
||||
)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
engine = RemoteSGLangEngine(config)
|
||||
engine.initialize(None, None)
|
||||
|
||||
|
@ -160,6 +141,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
|
||||
gconfig=gconfig,
|
||||
tokenizer=tokenizer,
|
||||
enable_thinking=False,
|
||||
)
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
|
@ -220,7 +202,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
|
|||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||||
|
||||
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||||
inf_engine = RemoteSGLangEngine(config)
|
||||
# test update weights
|
||||
path = tmp_path_factory.mktemp("upload_weights_from_disk")
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
import random
|
||||
import socket
|
||||
from typing import List, Set
|
||||
|
||||
|
||||
def gethostname():
|
||||
return socket.gethostname()
|
||||
|
||||
|
||||
def gethostip():
|
||||
return socket.gethostbyname(socket.gethostname())
|
||||
|
||||
|
||||
def find_free_ports(
|
||||
count: int, port_range: tuple = (1024, 65535), exclude_ports: Set[int] | None = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Find multiple free ports within a specified range.
|
||||
|
||||
Args:
|
||||
count: Number of free ports to find
|
||||
port_range: Tuple of (min_port, max_port) to search within
|
||||
exclude_ports: Set of ports to exclude from search
|
||||
|
||||
Returns:
|
||||
List of free port numbers
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to find requested number of free ports
|
||||
"""
|
||||
if exclude_ports is None:
|
||||
exclude_ports = set()
|
||||
|
||||
min_port, max_port = port_range
|
||||
free_ports = []
|
||||
attempted_ports = set()
|
||||
|
||||
# Calculate available port range
|
||||
available_range = max_port - min_port + 1 - len(exclude_ports)
|
||||
|
||||
if count > available_range:
|
||||
raise ValueError(
|
||||
f"Cannot find {count} ports in range {port_range}. "
|
||||
f"Only {available_range} ports available."
|
||||
)
|
||||
|
||||
max_attempts = count * 10 # Reasonable limit to avoid infinite loops
|
||||
attempts = 0
|
||||
|
||||
while len(free_ports) < count and attempts < max_attempts:
|
||||
# Generate random port within range
|
||||
port = random.randint(min_port, max_port)
|
||||
|
||||
# Skip if port already attempted or excluded
|
||||
if port in attempted_ports or port in exclude_ports:
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
attempted_ports.add(port)
|
||||
|
||||
if is_port_free(port):
|
||||
free_ports.append(port)
|
||||
|
||||
attempts += 1
|
||||
|
||||
if len(free_ports) < count:
|
||||
raise ValueError(
|
||||
f"Could only find {len(free_ports)} free ports "
|
||||
f"out of {count} requested after {max_attempts} attempts"
|
||||
)
|
||||
|
||||
return sorted(free_ports)
|
||||
|
||||
|
||||
def is_port_free(port: int) -> bool:
|
||||
"""
|
||||
Check if a port is free by attempting to bind to it.
|
||||
|
||||
Args:
|
||||
port: Port number to check
|
||||
|
||||
Returns:
|
||||
True if port is free, False otherwise
|
||||
"""
|
||||
# Check TCP
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.bind(("", port))
|
||||
sock.close()
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
# Check UDP
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
sock.bind(("", port))
|
||||
sock.close()
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
|
@ -17,19 +17,24 @@ class RLVRWorkflow(RolloutWorkflow):
|
|||
reward_fn,
|
||||
gconfig: GenerationHyperparameters,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
enable_thinking: bool,
|
||||
):
|
||||
self.reward_fn = reward_fn
|
||||
self.gconfig = gconfig
|
||||
self.tokenizer = tokenizer
|
||||
self.enable_thinking = enable_thinking
|
||||
|
||||
async def arun_episode(self, engine, data):
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
data["messages"], tokenize=False, add_generation_prompt=True
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
data["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=self.enable_thinking,
|
||||
)
|
||||
n_samples = self.gconfig.n_samples
|
||||
req = LLMRequest(
|
||||
rid=uuid.uuid4().hex,
|
||||
text=text,
|
||||
input_ids=input_ids,
|
||||
gconfig=self.gconfig.new(n_samples=1),
|
||||
)
|
||||
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
|
||||
|
@ -42,8 +47,8 @@ class RLVRWorkflow(RolloutWorkflow):
|
|||
versions = [-1] * resp.input_len + resp.output_versions
|
||||
|
||||
reward = self.reward_fn(
|
||||
prompt=req.text,
|
||||
completions=resp.completions,
|
||||
prompt=self.tokenizer.decode(input_ids),
|
||||
completions=self.tokenizer.decode(resp.output_tokens),
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
|
|
Loading…
Reference in New Issue