[Feature] Switch dataset path / model path to HF location to ease community usage (#82)

* Update .gitignore and switch dataset paths in the example scripts to Hugging Face dataset locations; refactor dataset loading functions to resolve paths through load_hf_or_local_file.

* Remove sglang subproject and update dataset path format in load_hf_or_local_file function for compatibility with Hugging Face datasets.

* Reorder the sympy imports in grader.py and parser.py for consistent import grouping.
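
In practice, dataset and model paths in the example configs may now be either local files or Hugging Face Hub locations. A minimal sketch of the intended resolution behavior (the local path shown is illustrative; resolved locations depend on the local Hugging Face cache):

    from realhf.utils import load_hf_or_local_file

    # Plain local paths pass through unchanged.
    load_hf_or_local_file("/data/boba_106k_0319.jsonl")

    # Hub URIs are downloaded once and resolved to a cached local path.
    load_hf_or_local_file(
        "hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl"
    )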
Ligeng Zhu 2025-06-06 21:38:06 +08:00 committed by GitHub
parent b768e5ce3c
commit de134b4a7a
11 changed files with 65 additions and 10 deletions

.gitignore (+2)

@@ -182,3 +182,5 @@ api_key.json
 .vscode/
 wandb/
+outputs/
+sympy/

@@ -2,12 +2,12 @@
 python3 training/main_async_ppo.py \
 n_nodes=1 n_gpus_per_node=8 \
 allocation_mode=sglang.d4p1m1+d2p2m1 \
-cluster.fileroot=/storage/testing/experiments \
+cluster.fileroot=experiments \
 actor.type._class=qwen3 \
-actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \
+actor.path=Qwen/Qwen3-1.7B \
 ref.type._class=qwen3 \
-ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \
-dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \
+ref.path=Qwen/Qwen3-1.7B \
+dataset.path=hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl \
 dataset.train_bs_n_seqs=32 \
 group_size=8 \
 ppo.gen.max_new_tokens=4096 \

@@ -14,10 +14,11 @@ from typing import Union
 import regex
 from latex2sympy2 import latex2sympy
+from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
-from sympy import N, simplify
 # from .parser import choice_answer_clean, strip_string
 # from parser import choice_answer_clean

@@ -3,10 +3,11 @@ import re
 from typing import Any, Dict, Iterable, List, TypeVar, Union
 import regex
+import sympy
 from latex2sympy2 import latex2sympy
 from word2number import w2n
-import sympy
 # from utils import *

@@ -14,6 +14,7 @@ from typing import Any, Dict, List
 import numpy as np
 from functioncall.code.verify import code_verify
+from realhf.utils import load_hf_or_local_file
 logger = logging.getLogger("function call")
@@ -223,6 +224,7 @@ def statics_result(result, query_ids):
 def standard_dataset_eval(
     dataset_path, code_count=0, test_case_batch_size=20, dry_run=False
 ):
+    dataset_path = load_hf_or_local_file(dataset_path)
     id2info = defaultdict(dict)
     generateds, query_ids = [], []
     cnt = 0

@@ -15,10 +15,11 @@ from typing import Union
 import regex
 from latex2sympy2 import latex2sympy
+from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
-from sympy import N, simplify
 # from .parser import choice_answer_clean, strip_string
 # from parser import choice_answer_clean

@@ -5,10 +5,11 @@ import re
 from typing import Any, Dict, Iterable, List, TypeVar, Union
 import regex
+import sympy
 from latex2sympy2 import latex2sympy
 from word2number import w2n
-import sympy
 # from utils import *

@@ -41,6 +41,7 @@ from realhf.api.cli_args import MicroBatchSpec
 from realhf.api.core import config as config_api
 from realhf.base import constants, datapack, logging, seeding
 from realhf.base.cluster import spec as cluster_spec
+from realhf.utils import load_hf_or_local_file
 logger = logging.getLogger("api.data")
@@ -756,6 +757,7 @@ def load_shuffle_split_dataset(
     dataset_path: str,
     dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None,
 ):
+    dataset_path = load_hf_or_local_file(dataset_path)
     if dataset_path is not None:
         if dataset_path.endswith(".jsonl"):
             with open(dataset_path, "r") as f:

@@ -14,6 +14,7 @@ import torch.utils.data
 from realhf.api.core import data_api
 from realhf.base import logging
+from realhf.utils import load_hf_or_local_file
 logger = logging.getLogger("Math Code Dataset")
@@ -54,6 +55,7 @@ def check_code_metadata_entries(data):
 def load_metadata(path):
     assert str(path).endswith(".jsonl"), path
+    path = load_hf_or_local_file(path)
     with open(path, "r") as f:
         data = [json.loads(l) for l in f.readlines()]

@@ -185,7 +185,7 @@ class GserverManager(AsyncWorker):
         if "num_paused_requests" in res:
             logger.info(
                 f"{res['num_paused_requests']} requests are interrupted "
-                f"during updateing weights for server {server_index}: {server_url}"
+                f"during updating weights for server {server_index}: {server_url}"
             )
             return
         logger.warning(

realhf/utils.py (new file, +43)

@@ -0,0 +1,43 @@
def download_from_huggingface(
    repo_id: str, filename: str, revision: str = "main", repo_type: str = "dataset"
) -> str:
    """Download a file from a Hugging Face Hub repository."""
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise ImportError(
            "Please install huggingface_hub to use this function: "
            "pip install huggingface_hub"
        )
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        repo_type=repo_type,
    )


def load_hf_or_local_file(path: str) -> str:
    """Resolve a Hugging Face Hub URI or a local path to a local file path.

    Supported formats:
        hf://<org>/<repo>/<filename>
        hf://<org>/<repo>@<revision>/<filename>
        hf-dataset://<org>/<repo>/<filename>

    e.g.,
        hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl
        =>
        repo_type = dataset
        repo_id   = inclusionAI/AReaL-RL-Data
        filename  = data/boba_106k_0319.jsonl
        revision  = main
        =>
        /root/.cache/huggingface/hub/datasets--inclusionAI--AReaL-RL-Data/snapshots/<revision>/data/boba_106k_0319.jsonl
    """
    if path.startswith("hf://") or path.startswith("hf-dataset://"):
        # hf:// points at a model repo, hf-dataset:// at a dataset repo.
        repo_type = "dataset" if path.startswith("hf-dataset://") else "model"
        hf_path = path.strip().split("://")[1]
        hf_org, hf_repo, filename = hf_path.split("/", 2)
        repo_id = f"{hf_org}/{hf_repo}"
        revision = "main"
        if "@" in repo_id:
            # An optional @<revision> may be attached to the repo name.
            repo_id, revision = repo_id.split("@", 1)
        return download_from_huggingface(repo_id, filename, revision, repo_type)
    return path
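
For reference, a short usage sketch matching the call sites changed above (resolve the path first, then read the JSONL as before); the dataset URI is the one from the example script, and downloads are served from the local Hugging Face cache after the first use:

    import json

    from realhf.utils import load_hf_or_local_file

    path = load_hf_or_local_file(
        "hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl"
    )  # downloads on first use, then returns the cached local path
    with open(path, "r") as f:
        data = [json.loads(line) for line in f]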