mirror of https://github.com/inclusionAI/AReaL
0724_merge3
This commit is contained in:
parent
e97e33fca8
commit
176ec4bb23
|
@ -697,7 +697,8 @@ class DatasetConfig:
|
|||
},
|
||||
)
|
||||
type: Optional[str] = field(
|
||||
default=None, metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."}
|
||||
default=None,
|
||||
metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."},
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=1, metadata={"help": "Batch size of the dataloader"}
|
||||
|
|
|
@ -29,13 +29,17 @@ def get_custom_dataset(
|
|||
get_clevr_count_70k_sft_dataset,
|
||||
)
|
||||
|
||||
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size, **kwargs)
|
||||
return get_clevr_count_70k_sft_dataset(
|
||||
path, split, processor, rank, world_size, **kwargs
|
||||
)
|
||||
elif "clevr_count_70k" in path and type == "rl":
|
||||
from examples.arealite.dataset.clevr_count_70k import (
|
||||
get_clevr_count_70k_rl_dataset,
|
||||
)
|
||||
|
||||
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size, **kwargs)
|
||||
return get_clevr_count_70k_rl_dataset(
|
||||
path, split, processor, rank, world_size, **kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dataset {path} with split {split} and training type {type} is not supported. "
|
||||
|
|
|
@ -96,7 +96,7 @@ class BaseHFEngine(TrainEngine):
|
|||
dtype = getattr(torch, self.config.dtype)
|
||||
|
||||
if self.is_vision_model:
|
||||
if dtype== torch.float16:
|
||||
if dtype == torch.float16:
|
||||
raise ValueError(
|
||||
"Vision models do not support float16 dtype. Please use bfloat16."
|
||||
)
|
||||
|
|
|
@ -166,7 +166,7 @@ def concat_padded_tensors(
|
|||
tensor = torch.cat([tensor, padding], dim=1)
|
||||
tensors_to_concat.append(tensor)
|
||||
|
||||
result[key] = torch.cat(tensors_to_concat, dim=0)
|
||||
result[key] = torch.cat(tensors_to_concat, dim=0)
|
||||
return TensorDict(result, batch_size=new_batch_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import torch
|
||||
|
||||
|
||||
|
||||
VALID_VISION_MODELS = [
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch.distributed as dist
|
|||
import wandb
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
|
||||
from arealite.api.cli_args import GRPOConfig, load_expr_config
|
||||
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
|
||||
from arealite.dataset.__init__ import get_custom_dataset
|
||||
|
@ -16,10 +17,10 @@ from arealite.utils.device import log_gpu_stats
|
|||
from arealite.utils.evaluator import Evaluator
|
||||
from arealite.utils.saver import Saver
|
||||
from arealite.utils.stats_logger import StatsLogger
|
||||
from AReaL.arealite.workflow.Visionrlvr import VisionRLVRWorkflow
|
||||
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
|
||||
from realhf.base import stats_tracker
|
||||
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
|
||||
if match:
|
||||
|
@ -27,6 +28,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
|
|||
|
||||
return ""
|
||||
|
||||
|
||||
def clevr_count_70k_reward_fn(
|
||||
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
|
||||
):
|
||||
|
@ -52,6 +54,7 @@ def clevr_count_70k_reward_fn(
|
|||
|
||||
return 0
|
||||
|
||||
|
||||
def main(args):
|
||||
wandb.init(project="clevr_70k")
|
||||
|
||||
|
|
|
@ -1555,4 +1555,4 @@ def print_runtime_helper(args):
|
|||
)
|
||||
|
||||
# Configuration options section
|
||||
print_config_values(args)
|
||||
print_config_values(args)
|
||||
|
|
|
@ -43,7 +43,8 @@ def gpu_count():
|
|||
continue
|
||||
else:
|
||||
break
|
||||
return cnt
|
||||
return cnt
|
||||
|
||||
|
||||
def set_cuda_device(device):
|
||||
"""Set the default cuda-device.
|
||||
|
|
|
@ -170,4 +170,4 @@ def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient:
|
|||
|
||||
return LocalSchedulerClient(args)
|
||||
else:
|
||||
raise NotImplementedError(f"Scheduler {args.mode} not found")
|
||||
raise NotImplementedError(f"Scheduler {args.mode} not found")
|
||||
|
|
|
@ -122,7 +122,7 @@ class LocalSchedulerClient(SchedulerClient):
|
|||
gpu > 0
|
||||
), "All workers of the same type must either use GPU or not use GPU."
|
||||
else:
|
||||
self._job_with_gpu[worker_type] = (gpu > 0)
|
||||
self._job_with_gpu[worker_type] = gpu > 0
|
||||
|
||||
if worker_type in self._job_env_vars:
|
||||
assert (
|
||||
|
|
|
@ -174,4 +174,4 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
|
|||
)
|
||||
addr = f"{host}:{port}"
|
||||
name_resolve.add(name, addr)
|
||||
super().__init__(host, port, **kwargs)
|
||||
super().__init__(host, port, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue