0724_merge1

朱晗 2025-07-24 14:18:42 +08:00
parent 6c28d52387
commit c816a3ce44
23 changed files with 68 additions and 135 deletions

View File

@@ -697,7 +697,7 @@ class DatasetConfig:
},
)
type: Optional[str] = field(
default=None, metadata={"help": "Type of implemented dataset"}
default=None, metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."}
)
batch_size: int = field(
default=1, metadata={"help": "Batch size of the dataloader"}
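For orientation, a minimal sketch of how the reworded field is consumed. The class below is a reconstruction from the visible hunk and the YAML diffs later in this commit; any field not shown above (such as path) is an assumption, and the class name is deliberately not the repository's:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DatasetConfigSketch:
    # Reconstructed sketch, not the full DatasetConfig class.
    path: Optional[str] = None
    type: Optional[str] = field(
        default=None, metadata={"help": "Type of training method, e.g., 'sft', 'rl', etc."}
    )
    batch_size: int = field(
        default=1, metadata={"help": "Batch size of the dataloader"}
    )

cfg = DatasetConfigSketch(path="openai/gsm8k", type="sft")
assert cfg.type in ("sft", "rl")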

View File

@@ -9,38 +9,35 @@ def get_custom_dataset(
path: str,
rank: int,
world_size: int,
training_type: str = "sft",
type: str = "sft",
split: Optional[str] = None,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
processor: Optional[transformers.AutoProcessor] = None,
**kwargs,
):
if "gsm8k" in path and training_type == "sft":
if "gsm8k" in path and type == "sft":
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
elif "gsm8k" in path and training_type == "rl":
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size, **kwargs)
elif "gsm8k" in path and type == "rl":
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
return get_gsm8k_rl_dataset(path, split, rank, world_size)
elif "clevr_count_70k" in path and training_type == "sft":
return get_gsm8k_rl_dataset(path, split, rank, world_size, **kwargs)
elif "clevr_count_70k" in path and type == "sft":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_sft_dataset,
)
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
elif "clevr_count_70k" in path and training_type == "rl":
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size, **kwargs)
elif "clevr_count_70k" in path and type == "rl":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_rl_dataset,
)
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
elif "geometry3k" in path and training_type == "rl":
from examples.arealite.dataset.geometry3k import get_geometry3k_rl_dataset
return get_geometry3k_rl_dataset(path, split, processor, rank, world_size)
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size, **kwargs)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {training_type} is not supported. "
f"Dataset {path} with split {split} and training type {type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)
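A usage sketch of the renamed dispatcher, assuming get_custom_dataset is imported from the module shown above. The keyword is now type rather than training_type, and extra keyword arguments are forwarded to the per-dataset loaders via **kwargs; the model name and max_length below are illustrative assumptions:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model
dataset = get_custom_dataset(
    path="openai/gsm8k",
    rank=0,
    world_size=1,
    type="sft",        # was training_type before this commit
    split="train",
    tokenizer=tokenizer,
    max_length=1024,   # assumed example of a kwarg forwarded to get_gsm8k_sft_dataset
)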

View File

@@ -96,7 +96,14 @@ class BaseHFEngine(TrainEngine):
dtype = getattr(torch, self.config.dtype)
if self.is_vision_model:
dtype = torch.bfloat16
if dtype == torch.float16:
raise ValueError(
"Vision models do not support float16 dtype. Please use bfloat16."
)
if self.config.init_from_scratch:
raise ValueError(
"Vision models do not support initialization from scratch. Please use a pretrained model."
)
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
self.config.path
)
@@ -304,6 +311,7 @@ class BaseHFEngine(TrainEngine):
loss_scale *= self.world_size
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
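The newly added clipping call follows the standard PyTorch pattern: clip_grad_norm_ rescales gradients in place and returns the total pre-clipping norm, which is convenient to log. A self-contained sketch (max_norm=1.0 is illustrative, not the repository's setting):

import torch

model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(grad_norm))  # total gradient norm before clipping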

View File

@@ -1,9 +1,13 @@
import torch
VALID_VISION_MODELS = [
"qwen2_vl",
"qwen2_5_vl",
]
# This registry lists vision models that we have verified to work with AReaLite. Vision models differ in their image processing, special tokens, keys, and so on, so we add a model here only after testing it.
# If you want to add a new vision model, please make sure it works with AReaLite.
# Copied from trl
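A sketch of how such a registry is typically consulted; is_valid_vision_model is a hypothetical helper, not an API confirmed by this diff:

def is_valid_vision_model(model_type: str) -> bool:
    # model_type is the HF config's model_type string, e.g. "qwen2_5_vl".
    return model_type in VALID_VISION_MODELS

assert is_valid_vision_model("qwen2_5_vl")
assert not is_valid_vision_model("llama")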

View File

@@ -12,7 +12,7 @@ from arealite.utils.image import image2base64
from arealite.workflow.rlvr import RLVRWorkflow
class VL_RLVRWorkflow(RLVRWorkflow):
class VisionRLVRWorkflow(RLVRWorkflow):
def __init__(
self,
reward_fn,
@@ -67,14 +67,8 @@ class VL_RLVRWorkflow(RLVRWorkflow):
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
pixel_values=processed_input["pixel_values"]
.clone()
.detach()
.unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"]
.clone()
.detach()
.unsqueeze(0),
pixel_values=processed_input["pixel_values"].unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"].unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
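The simplification above drops .clone().detach() on the processor outputs and keeps only the unsqueeze(0) that adds a leading batch axis; every per-sequence tensor in the sample follows that convention. A minimal illustration with made-up values:

import torch

seq = [1, 2, 3]
input_ids = torch.tensor(seq).unsqueeze(0)                            # shape (1, 3)
attention_mask = torch.ones(len(seq), dtype=torch.bool).unsqueeze(0)  # shape (1, 3)
assert input_ids.shape == attention_mask.shape == (1, 3)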

View File

@@ -16,11 +16,10 @@ from arealite.utils.device import log_gpu_stats
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from arealite.workflow.vl_rlvr import VL_RLVRWorkflow
from arealite.workflow.vl_rlvr import VisionRLVRWorkflow
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import stats_tracker
def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
@@ -28,35 +27,6 @@ def extract_answer(pred_str, data_name, use_last_number=True):
return ""
# Adapted from verl.
def extract_solution(solution_str, method="strict") -> str | None:
assert method in ["strict", "flexible"]
final_answer = None
if method == "strict":
# this also tests the formatting of the model
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
if len(solutions) == 0:
final_answer = None
else:
# take the last solution
final_answer = solutions[-1].replace(",", "").replace("$", "")
elif method == "flexible":
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
final_answer = None
if len(answer) == 0:
# no reward is there is no answer
pass
else:
invalid_str = ["", "."]
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
@@ -75,14 +45,13 @@ def clevr_count_70k_reward_fn(
if is_thinking:
return 1
else:
return 1
return 0.8
if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
return 0.05
return 0
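Pieced together from the visible hunks, the reward appears tiered: full credit for a correct answer produced with visible reasoning, slightly less without it, a small format bonus for a well-formed but wrong "[N]" answer, and zero otherwise. The sketch below reconstructs that shape under stated assumptions; the extraction logic and the is_thinking check are simplified stand-ins, since the full function body is not visible in this diff:

import re

def clevr_count_70k_reward_sketch(completion: str, answer: str) -> float:
    sol = completion.strip()                    # assumed: solution text taken from the completion
    match = re.findall(r"\[([0-9\.]+)\]", sol)  # answers written as "[N]", per extract_answer above
    pred = match[-1] if match else ""
    is_thinking = "</think>" in completion      # assumed thinking-format check
    if pred and pred == answer:
        return 1.0 if is_thinking else 0.8
    if re.match(r"^\[\d+(\.\d+)?\]$", sol):
        return 0.05                             # well-formed but incorrect
    return 0.0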
def main(args):
wandb.init(project="clevr_70k")
@@ -97,7 +66,7 @@ def main(args):
rank=rank,
world_size=world_size,
split="train",
training_type="rl",
type=config.train_dataset.type,
processor=processor,
)
valid_dataset = get_custom_dataset(
@@ -105,7 +74,7 @@ def main(args):
rank=rank,
world_size=world_size,
split="test",
training_type="rl",
type=config.valid_dataset.type,
processor=processor,
)
# Create dataset and dataloaders
@@ -153,7 +122,7 @@ def main(args):
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = VL_RLVRWorkflow(
workflow = VisionRLVRWorkflow(
reward_fn=clevr_count_70k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,

View File

@@ -27,7 +27,7 @@ def main_sft():
rank=rank,
world_size=world_size,
split="train",
training_type="sft",
type=config.train_dataset.type,
tokenizer=tokenizer,
processor=processor,
)
@@ -36,7 +36,7 @@ def main_sft():
rank=rank,
world_size=world_size,
split="test",
training_type="sft",
type=config.valid_dataset.type,
tokenizer=tokenizer,
processor=processor,
)

View File

@@ -98,6 +98,7 @@ train_dataset:
pin_memory: true
num_workers: 4
path: /storage/openpsi/data/clevr_count_70k/
type: rl
valid_dataset:
batch_size: 32
@@ -105,6 +106,7 @@ valid_dataset:
pin_memory: true
num_workers: 4
path: /storage/openpsi/data/clevr_count_70k/
type: rl
# Utilities
saver:
@@ -134,16 +136,4 @@ evaluator:
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled
# launcher:
# inference_server_cpus_per_gpu: 15
# inference_server_mem_per_gpu: 153600
# trainer_cpus_per_gpu: 15
# trainer_mem_per_gpu: 153600
# slurm:
# mount: /storage:/storage
# trainer_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
# inference_server_image: /storage/openpsi/images/areal-v0.3.0.post1.sif
fileroot: ${cluster.fileroot}
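The new type keys let the training scripts read the dataset mode from the config instead of hard-coding it. A minimal check with OmegaConf (which these configs already rely on for ${...} interpolation), over an inline copy of the relevant keys:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    train_dataset:
      path: /storage/openpsi/data/clevr_count_70k/
      type: rl
    valid_dataset:
      path: /storage/openpsi/data/clevr_count_70k/
      type: rl
    """
)
assert cfg.train_dataset.type == "rl" and cfg.valid_dataset.type == "rl"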

View File

@@ -40,6 +40,7 @@ train_dataset:
pin_memory: true
num_workers: 4
path: /storage/openpsi/data/clevr_count_70k/
type: sft
valid_dataset:
batch_size: 128
@@ -47,6 +48,7 @@ valid_dataset:
pin_memory: true
num_workers: 4
path: /storage/openpsi/data/clevr_count_70k/
type: sft
# Utilities
saver:
@@ -76,6 +78,4 @@ evaluator:
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled
fileroot: ${cluster.fileroot}

View File

@@ -92,6 +92,7 @@ train_dataset:
pin_memory: true
num_workers: 4
path: openai/gsm8k
type: rl
valid_dataset:
batch_size: 256
@@ -99,6 +100,7 @@ valid_dataset:
pin_memory: true
num_workers: 4
path: openai/gsm8k
type: rl
# Utilities
saver:
@@ -129,6 +131,4 @@ stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled

View File

@@ -37,6 +37,7 @@ train_dataset:
pin_memory: true
num_workers: 4
path: openai/gsm8k
type: sft
valid_dataset:
batch_size: 128
@@ -44,6 +45,7 @@ valid_dataset:
pin_memory: true
num_workers: 4
path: openai/gsm8k
type: sft
# Utilities
saver:
@@ -74,5 +76,3 @@ stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled

View File

@@ -31,7 +31,7 @@ dependencies = [
"huggingface_hub",
"datasets",
"accelerate",
"transformers==4.53.3",
"transformers>=4.53.3",
# Scientific computing
"numpy<2.0.0",

View File

@@ -303,6 +303,7 @@ class SGLangConfig:
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
hybrid_train: bool = False
dtype: str = "float16"
kv_cache_dtype: str = "auto"
@@ -317,19 +318,15 @@
# and update prometheus metrics
decode_log_interval: int = 1
# Not used.
hybrid_train: bool = False
# Use staticmethod to make OmegaConf happy.
@staticmethod
def build_cmd(
sglang_config: "SGLangConfig",
model_path,
tp_size,
server_index,
base_gpu_id,
dist_init_addr: Optional[str] = None,
served_model_name: Optional[str] = None,
skip_tokenizer_init: bool = True,
):
from realhf.base import constants, network, pkg_version, seeding
from realhf.experiments.common.utils import asdict as conf_as_dict
@@ -338,8 +335,6 @@
args.pop("hybrid_train")
args["random_seed"] = seeding.get_seed()
if served_model_name is None:
served_model_name = model_path
host_ip = network.gethostip()
host = "localhost" if not sglang_config.enable_metrics else host_ip
args = dict(
@@ -351,9 +346,9 @@
load_format="auto",
trust_remote_code=True,
device="cuda",
served_model_name=served_model_name,
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
is_embedding=False,
skip_tokenizer_init=skip_tokenizer_init,
skip_tokenizer_init=True,
# Other runtime options
tp_size=tp_size,
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
@@ -561,10 +556,6 @@ class GenerationHyperparameters:
default=1.0,
metadata={"help": "Sampling temperature. Higher values increase diversity."},
)
stop_token_ids: List[int] = field(
default_factory=list,
metadata={"help": "Stop generation when encoutering these token ids."},
)
# Deprecated parameters
use_cuda_graph: bool = field(
@@ -1564,4 +1555,4 @@ def print_runtime_helper(args):
)
# Configuration options section
print_config_values(args)
print_config_values(args)
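Two behavioral notes fall out of the SGLangConfig hunks above: served_model_name is no longer a build_cmd parameter but is derived from the experiment and trial names, and skip_tokenizer_init is now always True. A sketch of the resulting identifier, with illustrative values (the real ones come from constants.experiment_name() and constants.trial_name()):

experiment_name, trial_name = "my-exp", "trial-0"  # illustrative
model_path = "Qwen/Qwen2.5-7B"                     # illustrative
served_model_name = f"{experiment_name}/{trial_name}/{model_path}"
print(served_model_name)  # my-exp/trial-0/Qwen/Qwen2.5-7B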

View File

@@ -120,6 +120,7 @@ BASE_ENVIRONS = {
"REAL_IS_REMOTE": "1",
# "NCCL_P2P_DISABLE": "1",
# "NCCL_IB_DISABLE": "1",
"TRANSFORMERS_OFFLINE": "1",
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,

View File

@@ -27,16 +27,15 @@ def gpu_count():
Ad-hoc to frl cluster.
"""
try:
import torch
torch_cnt = torch.cuda.device_count()
except ImportError:
torch_cnt = 0
if platform.system() == "Darwin":
return 0
elif platform.system() == "Windows":
return torch_cnt
try:
import torch
return torch.cuda.device_count()
except ImportError:
return 0
else:
dev_directories = list(os.listdir("/dev/"))
for cnt in itertools.count():
@@ -44,8 +43,7 @@ def gpu_count():
continue
else:
break
return cnt or torch_cnt
return cnt
def set_cuda_device(device):
"""Set the default cuda-device.

View File

@@ -92,15 +92,6 @@ def stream_pullers(experiment_name, trial_name):
def gen_servers(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers"
def gen_server(experiment_name, trial_name, server_id):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server/{server_id}"
def gen_server_root(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server/"
def used_ports(experiment_name, trial_name, host_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/"

View File

@@ -30,7 +30,6 @@ def find_free_port(
trial_name="port",
lockfile_root=constants.PORT_LOCKFILE_ROOT,
):
# TODO: user random sampling instead of bind
"""Find a free port within the specified range, excluding certain ports."""
ports_name = names.used_ports(experiment_name, trial_name, gethostip())

View File

@@ -64,4 +64,4 @@ def get_trial_name(default_name: str = ""):
return trial_name
# global_init()
global_init()

realhf/experiments/async_exp/async_rl_exp.py Normal file → Executable file
View File

View File

@@ -796,7 +796,7 @@ def loadJson(dataDir):
return samples
def parse_line(id2info, generated, query_id):
def parse_line(id2info, prompt_str, generated, query_id):
info = id2info[query_id.split("@idx:")[0]]
label = 0

View File

@@ -298,15 +298,15 @@ class PPOActorInterface(model_api.ModelInterface):
)
@torch.no_grad()
def compute_logps(
def generate(
self,
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample:
module = model.module
module.eval()
self.engine.forward()
# Remap the key `packed_prompts` to `packed_input_ids`,
# because the pipe runner only recognizes `packed_input_ids`.

View File

@@ -170,4 +170,4 @@ def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient:
return LocalSchedulerClient(args)
else:
raise NotImplementedError(f"Scheduler {args.mode} not found")
raise NotImplementedError(f"Scheduler {args.mode} not found")

View File

@@ -25,19 +25,14 @@ class ZMQJsonPusher:
hwm: High-water mark for outgoing messages (default: 1000)
"""
def __init__(
self, host: str = "localhost", port: int = 5555, hwm: int = 1000, bind=False
):
def __init__(self, host: str = "localhost", port: int = 5555, hwm: int = 1000):
self.host = host
self.port = port
self.ctx = zmq.Context.instance()
self.socket = self.ctx.socket(zmq.PUSH)
self.socket.setsockopt(zmq.SNDHWM, hwm)
if not bind:
self.socket.connect(f"tcp://{self.host}:{self.port}")
else:
self.socket.bind(f"tcp://{self.host}:{self.port}")
self.socket.connect(f"tcp://{self.host}:{self.port}")
def push(self, data: JSONType) -> None:
"""
@@ -82,7 +77,6 @@ class ZMQJsonPuller:
port: int = 5555,
default_timeout_ms: int = 1000,
hwm: int = 1000,
bind: bool = True,
):
self.host = host
self.port = port
@@ -92,10 +86,7 @@ class ZMQJsonPuller:
self.socket = self.ctx.socket(zmq.PULL)
self.socket.setsockopt(zmq.RCVHWM, hwm)
self.socket.setsockopt(zmq.RCVTIMEO, self.default_timeout_ms)
if bind:
self.socket.bind(f"tcp://{self.host}:{self.port}")
else:
self.socket.connect(f"tcp://{self.host}:{self.port}")
self.socket.bind(f"tcp://{self.host}:{self.port}")
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
@@ -183,4 +174,4 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
)
addr = f"{host}:{port}"
name_resolve.add(name, addr)
super().__init__(host, port, **kwargs)
super().__init__(host, port, **kwargs)
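After this change the topology is fixed rather than configurable: the puller binds and the pusher connects, the usual ZeroMQ arrangement when many short-lived producers feed one long-lived consumer. A minimal pyzmq sketch of that pattern (port and payload are illustrative):

import zmq

ctx = zmq.Context.instance()

puller = ctx.socket(zmq.PULL)      # long-lived side binds
puller.setsockopt(zmq.RCVHWM, 1000)
puller.bind("tcp://127.0.0.1:5555")

pusher = ctx.socket(zmq.PUSH)      # transient side connects
pusher.setsockopt(zmq.SNDHWM, 1000)
pusher.connect("tcp://127.0.0.1:5555")

pusher.send_json({"reward": 1.0})
print(puller.recv_json())          # {'reward': 1.0}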