mirror of https://github.com/inclusionAI/AReaL
0724_merge1
This commit is contained in:
parent
6c28d52387
commit
c816a3ce44
|
@ -697,7 +697,7 @@ class DatasetConfig:
|
|||
},
|
||||
)
|
||||
type: Optional[str] = field(
|
||||
default=None, metadata={"help": "Type of implemented dataset"}
|
||||
default=None, metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."}
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=1, metadata={"help": "Batch size of the dataloader"}
|
||||
|
|
|
@ -9,38 +9,35 @@ def get_custom_dataset(
|
|||
path: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
training_type: str = "sft",
|
||||
type: str = "sft",
|
||||
split: Optional[str] = None,
|
||||
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
|
||||
processor: Optional[transformers.AutoProcessor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if "gsm8k" in path and training_type == "sft":
|
||||
if "gsm8k" in path and type == "sft":
|
||||
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
|
||||
|
||||
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
|
||||
elif "gsm8k" in path and training_type == "rl":
|
||||
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size, **kwargs)
|
||||
elif "gsm8k" in path and type == "rl":
|
||||
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
|
||||
|
||||
return get_gsm8k_rl_dataset(path, split, rank, world_size)
|
||||
elif "clevr_count_70k" in path and training_type == "sft":
|
||||
return get_gsm8k_rl_dataset(path, split, rank, world_size, **kwargs)
|
||||
elif "clevr_count_70k" in path and type == "sft":
|
||||
from examples.arealite.dataset.clevr_count_70k import (
|
||||
get_clevr_count_70k_sft_dataset,
|
||||
)
|
||||
|
||||
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
|
||||
elif "clevr_count_70k" in path and training_type == "rl":
|
||||
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size, **kwargs)
|
||||
elif "clevr_count_70k" in path and type == "rl":
|
||||
from examples.arealite.dataset.clevr_count_70k import (
|
||||
get_clevr_count_70k_rl_dataset,
|
||||
)
|
||||
|
||||
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
|
||||
elif "geometry3k" in path and training_type == "rl":
|
||||
from examples.arealite.dataset.geometry3k import get_geometry3k_rl_dataset
|
||||
|
||||
return get_geometry3k_rl_dataset(path, split, processor, rank, world_size)
|
||||
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dataset {path} with split {split} and training type {training_type} is not supported. "
|
||||
f"Dataset {path} with split {split} and training type {type} is not supported. "
|
||||
f"Supported datasets are: {VALID_DATASETS}. "
|
||||
)
|
||||
|
|
|
@ -96,7 +96,14 @@ class BaseHFEngine(TrainEngine):
|
|||
dtype = getattr(torch, self.config.dtype)
|
||||
|
||||
if self.is_vision_model:
|
||||
dtype = torch.bfloat16
|
||||
if dtype== torch.float16:
|
||||
raise ValueError(
|
||||
"Vision models do not support float16 dtype. Please use bfloat16."
|
||||
)
|
||||
if self.config.init_from_scratch:
|
||||
raise ValueError(
|
||||
"Vision models do not support initialization from scratch. Please use a pretrained model."
|
||||
)
|
||||
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
|
||||
self.config.path
|
||||
)
|
||||
|
@ -304,6 +311,7 @@ class BaseHFEngine(TrainEngine):
|
|||
loss_scale *= self.world_size
|
||||
|
||||
loss *= loss_scale
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import torch
|
||||
|
||||
|
||||
|
||||
VALID_VISION_MODELS = [
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
]
|
||||
# This registry is used to check if a model is a vision model that we have checked it works with AReaLite. As different vision models vary in their image processing, special tokens and keys, etc. We will add models to this registry as we test them.
|
||||
# If you want to add a new vision model, please make sure it works with AReaLite.
|
||||
|
||||
|
||||
# Copied from trl
|
||||
|
|
|
@ -12,7 +12,7 @@ from arealite.utils.image import image2base64
|
|||
from arealite.workflow.rlvr import RLVRWorkflow
|
||||
|
||||
|
||||
class VL_RLVRWorkflow(RLVRWorkflow):
|
||||
class VisionRLVRWorkflow(RLVRWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
reward_fn,
|
||||
|
@ -67,14 +67,8 @@ class VL_RLVRWorkflow(RLVRWorkflow):
|
|||
# unsqueeze to add an additional batch dimension
|
||||
input_ids=torch.tensor(seq).unsqueeze(0),
|
||||
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
|
||||
pixel_values=processed_input["pixel_values"]
|
||||
.clone()
|
||||
.detach()
|
||||
.unsqueeze(0),
|
||||
image_grid_thw=processed_input["image_grid_thw"]
|
||||
.clone()
|
||||
.detach()
|
||||
.unsqueeze(0),
|
||||
pixel_values=processed_input["pixel_values"].unsqueeze(0),
|
||||
image_grid_thw=processed_input["image_grid_thw"].unsqueeze(0),
|
||||
logprobs=torch.tensor(logprobs).unsqueeze(0),
|
||||
versions=torch.tensor(versions).unsqueeze(0),
|
||||
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
|
|
@ -16,11 +16,10 @@ from arealite.utils.device import log_gpu_stats
|
|||
from arealite.utils.evaluator import Evaluator
|
||||
from arealite.utils.saver import Saver
|
||||
from arealite.utils.stats_logger import StatsLogger
|
||||
from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
|
||||
from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
|
||||
if match:
|
||||
|
@ -28,35 +27,6 @@ def extract_answer(pred_str, data_name, use_last_number=True):
|
|||
|
||||
return ""
|
||||
|
||||
|
||||
# Adapted from verl.
|
||||
def extract_solution(solution_str, method="strict") -> str | None:
|
||||
assert method in ["strict", "flexible"]
|
||||
|
||||
final_answer = None
|
||||
if method == "strict":
|
||||
# this also tests the formatting of the model
|
||||
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
|
||||
if len(solutions) == 0:
|
||||
final_answer = None
|
||||
else:
|
||||
# take the last solution
|
||||
final_answer = solutions[-1].replace(",", "").replace("$", "")
|
||||
elif method == "flexible":
|
||||
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
|
||||
final_answer = None
|
||||
if len(answer) == 0:
|
||||
# no reward is there is no answer
|
||||
pass
|
||||
else:
|
||||
invalid_str = ["", "."]
|
||||
# find the last number that is not '.'
|
||||
for final_answer in reversed(answer):
|
||||
if final_answer not in invalid_str:
|
||||
break
|
||||
return final_answer
|
||||
|
||||
|
||||
def clevr_count_70k_reward_fn(
|
||||
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
|
||||
):
|
||||
|
@ -75,14 +45,13 @@ def clevr_count_70k_reward_fn(
|
|||
if is_thinking:
|
||||
return 1
|
||||
else:
|
||||
return 1
|
||||
return 0.8
|
||||
|
||||
if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
|
||||
return 0.05
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main(args):
|
||||
wandb.init(project="clevr_70k")
|
||||
|
||||
|
@ -97,7 +66,7 @@ def main(args):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="train",
|
||||
training_type="rl",
|
||||
training_type=config.train_dataset.type,
|
||||
processor=processor,
|
||||
)
|
||||
valid_dataset = get_custom_dataset(
|
||||
|
@ -105,7 +74,7 @@ def main(args):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="test",
|
||||
training_type="rl",
|
||||
training_type=config.valid_dataset.type,
|
||||
processor=processor,
|
||||
)
|
||||
# Create dataset and dataloaders
|
||||
|
@ -153,7 +122,7 @@ def main(args):
|
|||
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
|
||||
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
|
||||
|
||||
workflow = VL_RLVRWorkflow(
|
||||
workflow = VisionRLVRWorkflow(
|
||||
reward_fn=clevr_count_70k_reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
|
|
|
@ -27,7 +27,7 @@ def main_sft():
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="train",
|
||||
training_type="sft",
|
||||
training_type=config.train_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
|
@ -36,7 +36,7 @@ def main_sft():
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
split="test",
|
||||
training_type="sft",
|
||||
training_type=config.valid_dataset.type,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
|
|
|
@ -98,6 +98,7 @@ train_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: rl
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 32
|
||||
|
@ -105,6 +106,7 @@ valid_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: rl
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -134,16 +136,4 @@ evaluator:
|
|||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
||||
|
||||
# launcher:
|
||||
# inference_server_cpus_per_gpu: 15
|
||||
# inference_server_mem_per_gpu: 153600
|
||||
# trainer_cpus_per_gpu: 15
|
||||
# trainer_mem_per_gpu: 153600
|
||||
# slurm:
|
||||
# mount: /storage:/storage
|
||||
# trainer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
|
||||
# inference_server_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
|
||||
fileroot: ${cluster.fileroot}
|
|
@ -40,6 +40,7 @@ train_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: sft
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 128
|
||||
|
@ -47,6 +48,7 @@ valid_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: /storage/openpsi/data/clevr_count_70k/
|
||||
type: sft
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -76,6 +78,4 @@ evaluator:
|
|||
stats_logger:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
||||
fileroot: ${cluster.fileroot}
|
|
@ -92,6 +92,7 @@ train_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: rl
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 256
|
||||
|
@ -99,6 +100,7 @@ valid_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: rl
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -129,6 +131,4 @@ stats_logger:
|
|||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ train_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: sft
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 128
|
||||
|
@ -44,6 +45,7 @@ valid_dataset:
|
|||
pin_memory: true
|
||||
num_workers: 4
|
||||
path: openai/gsm8k
|
||||
type: sft
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -74,5 +76,3 @@ stats_logger:
|
|||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
|
@ -31,7 +31,7 @@ dependencies = [
|
|||
"huggingface_hub",
|
||||
"datasets",
|
||||
"accelerate",
|
||||
"transformers==4.53.3",
|
||||
"transformers>=4.53.3",
|
||||
|
||||
# Scientific computing
|
||||
"numpy<2.0.0",
|
||||
|
|
|
@ -303,6 +303,7 @@ class SGLangConfig:
|
|||
schedule_policy: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
hybrid_train: bool = False
|
||||
dtype: str = "float16"
|
||||
kv_cache_dtype: str = "auto"
|
||||
|
||||
|
@ -317,19 +318,15 @@ class SGLangConfig:
|
|||
# and update prometheus metrics
|
||||
decode_log_interval: int = 1
|
||||
|
||||
# Not used.
|
||||
hybrid_train: bool = False
|
||||
|
||||
# Use staticmethod to make OmegaConf happy.
|
||||
@staticmethod
|
||||
def build_cmd(
|
||||
sglang_config: "SGLangConfig",
|
||||
model_path,
|
||||
tp_size,
|
||||
server_index,
|
||||
base_gpu_id,
|
||||
dist_init_addr: Optional[str] = None,
|
||||
served_model_name: Optional[str] = None,
|
||||
skip_tokenizer_init: bool = True,
|
||||
):
|
||||
from realhf.base import constants, network, pkg_version, seeding
|
||||
from realhf.experiments.common.utils import asdict as conf_as_dict
|
||||
|
@ -338,8 +335,6 @@ class SGLangConfig:
|
|||
args.pop("hybrid_train")
|
||||
args["random_seed"] = seeding.get_seed()
|
||||
|
||||
if served_model_name is None:
|
||||
served_model_name = model_path
|
||||
host_ip = network.gethostip()
|
||||
host = "localhost" if not sglang_config.enable_metrics else host_ip
|
||||
args = dict(
|
||||
|
@ -351,9 +346,9 @@ class SGLangConfig:
|
|||
load_format="auto",
|
||||
trust_remote_code=True,
|
||||
device="cuda",
|
||||
served_model_name=served_model_name,
|
||||
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
|
||||
is_embedding=False,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
skip_tokenizer_init=True,
|
||||
# Other runtime options
|
||||
tp_size=tp_size,
|
||||
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
|
||||
|
@ -561,10 +556,6 @@ class GenerationHyperparameters:
|
|||
default=1.0,
|
||||
metadata={"help": "Sampling temperature. Higher values increase diversity."},
|
||||
)
|
||||
stop_token_ids: List[int] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "Stop generation when encoutering these token ids."},
|
||||
)
|
||||
|
||||
# Deprecated parameters
|
||||
use_cuda_graph: bool = field(
|
||||
|
@ -1564,4 +1555,4 @@ def print_runtime_helper(args):
|
|||
)
|
||||
|
||||
# Configuration options section
|
||||
print_config_values(args)
|
||||
print_config_values(args)
|
|
@ -120,6 +120,7 @@ BASE_ENVIRONS = {
|
|||
"REAL_IS_REMOTE": "1",
|
||||
# "NCCL_P2P_DISABLE": "1",
|
||||
# "NCCL_IB_DISABLE": "1",
|
||||
"TRANSFORMERS_OFFLINE": "1",
|
||||
"TOKENIZERS_PARALLELISM": "true",
|
||||
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
|
||||
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
|
||||
|
|
|
@ -27,16 +27,15 @@ def gpu_count():
|
|||
|
||||
Ad-hoc to frl cluster.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch_cnt = torch.cuda.device_count()
|
||||
except ImportError:
|
||||
torch_cnt = 0
|
||||
if platform.system() == "Darwin":
|
||||
return 0
|
||||
elif platform.system() == "Windows":
|
||||
return torch_cnt
|
||||
try:
|
||||
import torch
|
||||
|
||||
return torch.cuda.device_count()
|
||||
except ImportError:
|
||||
return 0
|
||||
else:
|
||||
dev_directories = list(os.listdir("/dev/"))
|
||||
for cnt in itertools.count():
|
||||
|
@ -44,8 +43,7 @@ def gpu_count():
|
|||
continue
|
||||
else:
|
||||
break
|
||||
return cnt or torch_cnt
|
||||
|
||||
return cnt
|
||||
|
||||
def set_cuda_device(device):
|
||||
"""Set the default cuda-device.
|
||||
|
|
|
@ -92,15 +92,6 @@ def stream_pullers(experiment_name, trial_name):
|
|||
def gen_servers(experiment_name, trial_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers"
|
||||
|
||||
|
||||
def gen_server(experiment_name, trial_name, server_id):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server/{server_id}"
|
||||
|
||||
|
||||
def gen_server_root(experiment_name, trial_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server/"
|
||||
|
||||
|
||||
def used_ports(experiment_name, trial_name, host_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/"
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ def find_free_port(
|
|||
trial_name="port",
|
||||
lockfile_root=constants.PORT_LOCKFILE_ROOT,
|
||||
):
|
||||
# TODO: user random sampling instead of bind
|
||||
"""Find a free port within the specified range, excluding certain ports."""
|
||||
|
||||
ports_name = names.used_ports(experiment_name, trial_name, gethostip())
|
||||
|
|
|
@ -64,4 +64,4 @@ def get_trial_name(default_name: str = ""):
|
|||
return trial_name
|
||||
|
||||
|
||||
# global_init()
|
||||
global_init()
|
||||
|
|
|
@ -796,7 +796,7 @@ def loadJson(dataDir):
|
|||
return samples
|
||||
|
||||
|
||||
def parse_line(id2info, generated, query_id):
|
||||
def parse_line(id2info, prompt_str, generated, query_id):
|
||||
info = id2info[query_id.split("@idx:")[0]]
|
||||
|
||||
label = 0
|
||||
|
|
|
@ -298,15 +298,15 @@ class PPOActorInterface(model_api.ModelInterface):
|
|||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_logps(
|
||||
def generate(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
input_: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample:
|
||||
module = model.module
|
||||
|
||||
module.eval()
|
||||
self.engine.forward()
|
||||
|
||||
# Remap the key `packed_prompts` to `packed_input_ids`,
|
||||
# because the pipe runner only recognizes `packed_input_ids`.
|
||||
|
|
|
@ -170,4 +170,4 @@ def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient:
|
|||
|
||||
return LocalSchedulerClient(args)
|
||||
else:
|
||||
raise NotImplementedError(f"Scheduler {args.mode} not found")
|
||||
raise NotImplementedError(f"Scheduler {args.mode} not found")
|
|
@ -25,19 +25,14 @@ class ZMQJsonPusher:
|
|||
hwm: High-water mark for outgoing messages (default: 1000)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, host: str = "localhost", port: int = 5555, hwm: int = 1000, bind=False
|
||||
):
|
||||
def __init__(self, host: str = "localhost", port: int = 5555, hwm: int = 1000):
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.ctx = zmq.Context.instance()
|
||||
self.socket = self.ctx.socket(zmq.PUSH)
|
||||
self.socket.setsockopt(zmq.SNDHWM, hwm)
|
||||
if not bind:
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
else:
|
||||
self.socket.bind(f"tcp://{self.host}:{self.port}")
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
|
||||
def push(self, data: JSONType) -> None:
|
||||
"""
|
||||
|
@ -82,7 +77,6 @@ class ZMQJsonPuller:
|
|||
port: int = 5555,
|
||||
default_timeout_ms: int = 1000,
|
||||
hwm: int = 1000,
|
||||
bind: bool = True,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
@ -92,10 +86,7 @@ class ZMQJsonPuller:
|
|||
self.socket = self.ctx.socket(zmq.PULL)
|
||||
self.socket.setsockopt(zmq.RCVHWM, hwm)
|
||||
self.socket.setsockopt(zmq.RCVTIMEO, self.default_timeout_ms)
|
||||
if bind:
|
||||
self.socket.bind(f"tcp://{self.host}:{self.port}")
|
||||
else:
|
||||
self.socket.connect(f"tcp://{self.host}:{self.port}")
|
||||
self.socket.bind(f"tcp://{self.host}:{self.port}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
|
@ -183,4 +174,4 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
|
|||
)
|
||||
addr = f"{host}:{port}"
|
||||
name_resolve.add(name, addr)
|
||||
super().__init__(host, port, **kwargs)
|
||||
super().__init__(host, port, **kwargs)
|
Loading…
Reference in New Issue