0724_merge3

朱晗 2025-07-24 14:29:29 +08:00
parent e97e33fca8
commit 176ec4bb23
11 changed files with 20 additions and 13 deletions

View File

@@ -697,7 +697,8 @@ class DatasetConfig:
         },
     )
     type: Optional[str] = field(
-        default=None, metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."}
+        default=None,
+        metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."},
     )
     batch_size: int = field(
         default=1, metadata={"help": "Batch size of the dataloader"}
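For context, a minimal runnable sketch of how dataclass fields like the two in this hunk behave. Only the `type` and `batch_size` fields come from the diff; the class name `DatasetConfigSketch` and the instantiation at the end are illustrative assumptions, not the real `DatasetConfig`.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DatasetConfigSketch:
    # Training-method selector matching the reformatted field above.
    type: Optional[str] = field(
        default=None,
        metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."},
    )
    # Per-process dataloader batch size.
    batch_size: int = field(
        default=1, metadata={"help": "Batch size of the dataloader"}
    )

cfg = DatasetConfigSketch(type="sft", batch_size=8)
print(cfg.type, cfg.batch_size)  # sft 8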

View File

@@ -29,13 +29,17 @@ def get_custom_dataset(
             get_clevr_count_70k_sft_dataset,
         )
-        return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size, **kwargs)
+        return get_clevr_count_70k_sft_dataset(
+            path, split, processor, rank, world_size, **kwargs
+        )
     elif "clevr_count_70k" in path and type == "rl":
         from examples.arealite.dataset.clevr_count_70k import (
             get_clevr_count_70k_rl_dataset,
         )
-        return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size, **kwargs)
+        return get_clevr_count_70k_rl_dataset(
+            path, split, processor, rank, world_size, **kwargs
+        )
     else:
         raise ValueError(
             f"Dataset {path} with split {split} and training type {type} is not supported. "

View File

@@ -96,7 +96,7 @@ class BaseHFEngine(TrainEngine):
         dtype = getattr(torch, self.config.dtype)
         if self.is_vision_model:
-            if dtype== torch.float16:
+            if dtype == torch.float16:
                 raise ValueError(
                     "Vision models do not support float16 dtype. Please use bfloat16."
                 )
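A small standalone sketch of the guard being touched up here, assuming (as the context lines show) that the configured dtype string is resolved with getattr(torch, ...). The function name `check_dtype` is illustrative, not part of the engine.

import torch

def check_dtype(dtype_name: str, is_vision_model: bool) -> torch.dtype:
    # Resolve the config string into a torch dtype, then reject fp16 for vision models.
    dtype = getattr(torch, dtype_name)
    if is_vision_model and dtype == torch.float16:
        raise ValueError(
            "Vision models do not support float16 dtype. Please use bfloat16."
        )
    return dtype

check_dtype("bfloat16", is_vision_model=True)   # passes
# check_dtype("float16", is_vision_model=True)  # raises ValueError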

View File

@@ -166,7 +166,7 @@ def concat_padded_tensors(
             tensor = torch.cat([tensor, padding], dim=1)
             tensors_to_concat.append(tensor)
-        result[key] = torch.cat(tensors_to_concat, dim=0)
+        result[key] = torch.cat(tensors_to_concat, dim=0)
     return TensorDict(result, batch_size=new_batch_size)
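For readers unfamiliar with concat_padded_tensors, a self-contained sketch of the operation around the changed line: per-key tensors are right-padded to a common length and the padded pieces are concatenated along the batch dimension. The shapes and zero padding value below are illustrative assumptions.

import torch

a = torch.ones(2, 3)  # batch 2, sequence length 3
b = torch.ones(4, 5)  # batch 4, sequence length 5
max_len = max(a.size(1), b.size(1))

tensors_to_concat = []
for tensor in (a, b):
    if tensor.size(1) < max_len:
        # Right-pad along the sequence dimension to the common length.
        padding = torch.zeros(tensor.size(0), max_len - tensor.size(1))
        tensor = torch.cat([tensor, padding], dim=1)
    tensors_to_concat.append(tensor)

merged = torch.cat(tensors_to_concat, dim=0)
print(merged.shape)  # torch.Size([6, 5])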

View File

@@ -1,7 +1,5 @@
-import torch
 VALID_VISION_MODELS = [
     "qwen2_vl",
     "qwen2_5_vl",

View File

@@ -7,6 +7,7 @@ import torch.distributed as dist
 import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader
+from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
 from arealite.api.cli_args import GRPOConfig, load_expr_config
 from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
 from arealite.dataset.__init__ import get_custom_dataset
@@ -16,10 +17,10 @@ from arealite.utils.device import log_gpu_stats
 from arealite.utils.evaluator import Evaluator
 from arealite.utils.saver import Saver
 from arealite.utils.stats_logger import StatsLogger
-from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
 from realhf.api.core.data_api import load_hf_processor_and_tokenizer
 from realhf.base import stats_tracker
+
 def extract_answer(pred_str, data_name, use_last_number=True):
     match = re.findall(r"\[([0-9\.]+)\]", pred_str)
     if match:
@@ -27,6 +28,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
     return ""
+
 def clevr_count_70k_reward_fn(
     prompt, completions, prompt_ids, completion_ids, answer, **kwargs
 ):
@@ -52,6 +54,7 @@ def clevr_count_70k_reward_fn(
     return 0
+
 def main(args):
     wandb.init(project="clevr_70k")
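A toy, runnable sketch of the answer extraction and reward shown in these hunks. The regular expression is copied from extract_answer; the comparison logic in `toy_reward` is an assumed simplification, not the actual clevr_count_70k_reward_fn.

import re

def extract_bracketed_number(pred_str: str) -> str:
    # Same pattern as extract_answer: numbers written like "[3]" or "[3.0]".
    match = re.findall(r"\[([0-9\.]+)\]", pred_str)
    return match[-1] if match else ""

def toy_reward(completion: str, answer: str) -> int:
    # Assumed simplification: 1 if the extracted count equals the reference, else 0.
    pred = extract_bracketed_number(completion)
    try:
        return int(float(pred) == float(answer))
    except ValueError:
        return 0

print(toy_reward("I count [3] objects.", "3"))  # 1
print(toy_reward("No idea.", "3"))              # 0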

View File

@@ -1555,4 +1555,4 @@ def print_runtime_helper(args):
     )
     # Configuration options section
-    print_config_values(args)
+    print_config_values(args)

View File

@@ -43,7 +43,8 @@ def gpu_count():
            continue
        else:
            break
-    return cnt
+    return cnt
+
 def set_cuda_device(device):
     """Set the default cuda-device.

View File

@@ -170,4 +170,4 @@ def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient:
         return LocalSchedulerClient(args)
     else:
-        raise NotImplementedError(f"Scheduler {args.mode} not found")
+        raise NotImplementedError(f"Scheduler {args.mode} not found")

View File

@@ -122,7 +122,7 @@ class LocalSchedulerClient(SchedulerClient):
                 gpu > 0
             ), "All workers of the same type must either use GPU or not use GPU."
         else:
-            self._job_with_gpu[worker_type] = (gpu > 0)
+            self._job_with_gpu[worker_type] = gpu > 0
         if worker_type in self._job_env_vars:
             assert (
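A hedged sketch of the bookkeeping this assertion protects: every worker of a given type must agree on whether it uses a GPU. Only the dictionary update and the assertion message come from the diff; the standalone `register` helper is illustrative.

job_with_gpu = {}

def register(worker_type: str, gpu: int) -> None:
    if worker_type in job_with_gpu:
        # Subsequent workers of the same type must match the first one's GPU usage.
        assert job_with_gpu[worker_type] == (gpu > 0), (
            "All workers of the same type must either use GPU or not use GPU."
        )
    else:
        job_with_gpu[worker_type] = gpu > 0

register("trainer", gpu=1)
register("trainer", gpu=2)    # fine, still a GPU worker
# register("trainer", gpu=0)  # would fail the assertion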

View File

@@ -174,4 +174,4 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
         )
         addr = f"{host}:{port}"
         name_resolve.add(name, addr)
-        super().__init__(host, port, **kwargs)
+        super().__init__(host, port, **kwargs)