diff --git a/.gitignore b/.gitignore index 33cb122..b43242f 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,5 @@ api_key.json .vscode/ wandb/ +outputs/ +sympy/ \ No newline at end of file diff --git a/examples/run_async_ppo.sh b/examples/run_async_ppo.sh index e6c7d43..4873a4d 100644 --- a/examples/run_async_ppo.sh +++ b/examples/run_async_ppo.sh @@ -2,12 +2,12 @@ python3 training/main_async_ppo.py \ n_nodes=1 n_gpus_per_node=8 \ allocation_mode=sglang.d4p1m1+d2p2m1 \ - cluster.fileroot=/storage/testing/experiments \ + cluster.fileroot=experiments \ actor.type._class=qwen3 \ - actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + actor.path=Qwen/Qwen3-1.7B \ ref.type._class=qwen3 \ - ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \ - dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \ + ref.path=Qwen/Qwen3-1.7B \ + dataset.path=hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl \ dataset.train_bs_n_seqs=32 \ group_size=8 \ ppo.gen.max_new_tokens=4096 \ @@ -15,4 +15,4 @@ python3 training/main_async_ppo.py \ actor_train.mb_spec.max_tokens_per_mb=32768 \ actor_inf.mb_spec.max_tokens_per_mb=32768 \ max_concurrent_rollouts=16 \ - max_head_offpolicyness=4 \ No newline at end of file + max_head_offpolicyness=4 diff --git a/functioncall/math/function/grader.py b/functioncall/math/function/grader.py index 19b83cf..12cb1d9 100644 --- a/functioncall/math/function/grader.py +++ b/functioncall/math/function/grader.py @@ -14,10 +14,11 @@ from typing import Union import regex from latex2sympy2 import latex2sympy -from sympy import N, simplify from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr +from sympy import N, simplify + # from .parser import choice_answer_clean, strip_string # from parser import choice_answer_clean diff --git a/functioncall/math/function/parser.py b/functioncall/math/function/parser.py index 6798304..5f0c867 100644 --- a/functioncall/math/function/parser.py +++ 
b/functioncall/math/function/parser.py @@ -3,10 +3,11 @@ import re from typing import Any, Dict, Iterable, List, TypeVar, Union import regex -import sympy from latex2sympy2 import latex2sympy from word2number import w2n +import sympy + # from utils import * diff --git a/functioncall/test/performance_eval.py b/functioncall/test/performance_eval.py index e9e1175..85e0f83 100644 --- a/functioncall/test/performance_eval.py +++ b/functioncall/test/performance_eval.py @@ -14,6 +14,7 @@ from typing import Any, Dict, List import numpy as np from functioncall.code.verify import code_verify +from realhf.utils import load_hf_or_local_file logger = logging.getLogger("function call") @@ -223,6 +224,7 @@ def statics_result(result, query_ids): def standard_dataset_eval( dataset_path, code_count=0, test_case_batch_size=20, dry_run=False ): + dataset_path = load_hf_or_local_file(dataset_path) id2info = defaultdict(dict) generateds, query_ids = [], [] cnt = 0 diff --git a/grader.py b/grader.py index 916194e..641b608 100644 --- a/grader.py +++ b/grader.py @@ -15,10 +15,11 @@ from typing import Union import regex from latex2sympy2 import latex2sympy -from sympy import N, simplify from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr +from sympy import N, simplify + # from .parser import choice_answer_clean, strip_string # from parser import choice_answer_clean diff --git a/parser.py b/parser.py index 403d6d8..1958e9a 100644 --- a/parser.py +++ b/parser.py @@ -5,10 +5,11 @@ import re from typing import Any, Dict, Iterable, List, TypeVar, Union import regex -import sympy from latex2sympy2 import latex2sympy from word2number import w2n +import sympy + # from utils import * diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py index ca68a54..e2e5951 100644 --- a/realhf/api/core/data_api.py +++ b/realhf/api/core/data_api.py @@ -41,6 +41,7 @@ from realhf.api.cli_args import MicroBatchSpec from realhf.api.core import config as 
config_api from realhf.base import constants, datapack, logging, seeding from realhf.base.cluster import spec as cluster_spec +from realhf.utils import load_hf_or_local_file logger = logging.getLogger("api.data") @@ -756,6 +757,7 @@ def load_shuffle_split_dataset( dataset_path: str, dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None, ): + dataset_path = load_hf_or_local_file(dataset_path) if dataset_path is not None: if dataset_path.endswith(".jsonl"): with open(dataset_path, "r") as f: diff --git a/realhf/impl/dataset/math_code_dataset.py b/realhf/impl/dataset/math_code_dataset.py index 5c98adf..aa968ab 100644 --- a/realhf/impl/dataset/math_code_dataset.py +++ b/realhf/impl/dataset/math_code_dataset.py @@ -14,6 +14,7 @@ import torch.utils.data from realhf.api.core import data_api from realhf.base import logging +from realhf.utils import load_hf_or_local_file logger = logging.getLogger("Math Code Dataset") @@ -54,6 +55,7 @@ def check_code_metadata_entries(data): def load_metadata(path): assert str(path).endswith(".jsonl"), path + path = load_hf_or_local_file(path) with open(path, "r") as f: data = [json.loads(l) for l in f.readlines()] diff --git a/realhf/system/gserver_manager.py b/realhf/system/gserver_manager.py index 6a14093..aabdefc 100644 --- a/realhf/system/gserver_manager.py +++ b/realhf/system/gserver_manager.py @@ -185,7 +185,7 @@ class GserverManager(AsyncWorker): if "num_paused_requests" in res: logger.info( f"{res['num_paused_requests']} requests are interrupted " - f"during updateing weights for server {server_index}: {server_url}" + f"during updating weights for server {server_index}: {server_url}" ) return logger.warning( diff --git a/realhf/utils.py b/realhf/utils.py new file mode 100644 index 0000000..de85b1b --- /dev/null +++ b/realhf/utils.py @@ -0,0 +1,43 @@ +def download_from_huggingface(repo_id: str, filename: str, revision: str = "main", repo_type: str = "dataset") -> str: + """ + Download a file from a HuggingFace Hub 
repository.
+    """
+    try:
+        from huggingface_hub import hf_hub_download
+    except ImportError:
+        raise ImportError("Please install huggingface_hub to use this function: pip install huggingface_hub")
+
+    return hf_hub_download(
+        repo_id=repo_id,
+        filename=filename,
+        revision=revision,
+        repo_type=repo_type,
+    )
+
+
+def load_hf_or_local_file(path: str) -> str:
+    """
+    Resolve a HuggingFace Hub URI to a local file; other paths (and None) pass through.
+    hf://{org}/{repo}/{filename}              (model repo, revision defaults to main)
+    hf://{org}/{repo}@{revision}/{filename}   (use hf-dataset:// for dataset repos)
+
+    e.g,
+    hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl
+    =>
+    repo_type = dataset
+    repo_id = inclusionAI/AReaL-RL-Data
+    filename = data/boba_106k_0319.jsonl
+    revision = main
+    =>
+    ~/.cache/huggingface/hub/datasets--inclusionAI--AReaL-RL-Data/snapshots/{commit}/data/boba_106k_0319.jsonl
+    """
+    # Callers may pass None (e.g. an optional dataset path); it is returned as-is.
+    if path is not None and path.startswith(("hf://", "hf-dataset://")):
+        repo_type = "dataset" if path.startswith("hf-dataset://") else "model"
+        hf_path = path.strip().split("://", 1)[1]
+        hf_org, hf_repo, filename = hf_path.split("/", 2)
+        # An optional "@{revision}" suffix on the repo name selects a revision.
+        repo_id, _, revision = f"{hf_org}/{hf_repo}".partition("@")
+        # repo_type must be forwarded, otherwise hf:// model URIs hit the dataset default.
+        return download_from_huggingface(repo_id, filename, revision or "main", repo_type)
+    return path