mirror of https://github.com/inclusionAI/AReaL
[Feature] Switch dataset path / model path to HF location to ease community usage (#82)
* Update .gitignore and modify dataset paths in scripts for improved file management and compatibility with Hugging Face datasets. Additionally, refactor dataset loading functions to utilize load_hf_or_local_file for better flexibility. * Remove sglang subproject and update dataset path format in load_hf_or_local_file function for compatibility with Hugging Face datasets. * Refactor imports in grader.py and parser.py to include sympy for improved functionality.
This commit is contained in:
parent
b768e5ce3c
commit
de134b4a7a
|
@ -182,3 +182,5 @@ api_key.json
|
|||
|
||||
.vscode/
|
||||
wandb/
|
||||
outputs/
|
||||
sympy/
|
|
@ -2,12 +2,12 @@
|
|||
python3 training/main_async_ppo.py \
|
||||
n_nodes=1 n_gpus_per_node=8 \
|
||||
allocation_mode=sglang.d4p1m1+d2p2m1 \
|
||||
cluster.fileroot=/storage/testing/experiments \
|
||||
cluster.fileroot=experiments \
|
||||
actor.type._class=qwen3 \
|
||||
actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
actor.path=Qwen/Qwen3-1.7B \
|
||||
ref.type._class=qwen3 \
|
||||
ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \
|
||||
ref.path=Qwen/Qwen3-1.7B \
|
||||
dataset.path=hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl \
|
||||
dataset.train_bs_n_seqs=32 \
|
||||
group_size=8 \
|
||||
ppo.gen.max_new_tokens=4096 \
|
||||
|
|
|
@ -14,10 +14,11 @@ from typing import Union
|
|||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
from sympy import N, simplify
|
||||
|
||||
# from .parser import choice_answer_clean, strip_string
|
||||
# from parser import choice_answer_clean
|
||||
|
||||
|
|
|
@ -3,10 +3,11 @@ import re
|
|||
from typing import Any, Dict, Iterable, List, TypeVar, Union
|
||||
|
||||
import regex
|
||||
import sympy
|
||||
from latex2sympy2 import latex2sympy
|
||||
from word2number import w2n
|
||||
|
||||
import sympy
|
||||
|
||||
# from utils import *
|
||||
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from typing import Any, Dict, List
|
|||
import numpy as np
|
||||
|
||||
from functioncall.code.verify import code_verify
|
||||
from realhf.utils import load_hf_or_local_file
|
||||
|
||||
logger = logging.getLogger("function call")
|
||||
|
||||
|
@ -223,6 +224,7 @@ def statics_result(result, query_ids):
|
|||
def standard_dataset_eval(
|
||||
dataset_path, code_count=0, test_case_batch_size=20, dry_run=False
|
||||
):
|
||||
dataset_path = load_hf_or_local_file(dataset_path)
|
||||
id2info = defaultdict(dict)
|
||||
generateds, query_ids = [], []
|
||||
cnt = 0
|
||||
|
|
|
@ -15,10 +15,11 @@ from typing import Union
|
|||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
from sympy import N, simplify
|
||||
|
||||
# from .parser import choice_answer_clean, strip_string
|
||||
# from parser import choice_answer_clean
|
||||
|
||||
|
|
|
@ -5,10 +5,11 @@ import re
|
|||
from typing import Any, Dict, Iterable, List, TypeVar, Union
|
||||
|
||||
import regex
|
||||
import sympy
|
||||
from latex2sympy2 import latex2sympy
|
||||
from word2number import w2n
|
||||
|
||||
import sympy
|
||||
|
||||
# from utils import *
|
||||
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ from realhf.api.cli_args import MicroBatchSpec
|
|||
from realhf.api.core import config as config_api
|
||||
from realhf.base import constants, datapack, logging, seeding
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.utils import load_hf_or_local_file
|
||||
|
||||
logger = logging.getLogger("api.data")
|
||||
|
||||
|
@ -756,6 +757,7 @@ def load_shuffle_split_dataset(
|
|||
dataset_path: str,
|
||||
dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None,
|
||||
):
|
||||
dataset_path = load_hf_or_local_file(dataset_path)
|
||||
if dataset_path is not None:
|
||||
if dataset_path.endswith(".jsonl"):
|
||||
with open(dataset_path, "r") as f:
|
||||
|
|
|
@ -14,6 +14,7 @@ import torch.utils.data
|
|||
|
||||
from realhf.api.core import data_api
|
||||
from realhf.base import logging
|
||||
from realhf.utils import load_hf_or_local_file
|
||||
|
||||
logger = logging.getLogger("Math Code Dataset")
|
||||
|
||||
|
@ -54,6 +55,7 @@ def check_code_metadata_entries(data):
|
|||
|
||||
def load_metadata(path):
|
||||
assert str(path).endswith(".jsonl"), path
|
||||
path = load_hf_or_local_file(path)
|
||||
with open(path, "r") as f:
|
||||
data = [json.loads(l) for l in f.readlines()]
|
||||
|
||||
|
|
|
@ -185,7 +185,7 @@ class GserverManager(AsyncWorker):
|
|||
if "num_paused_requests" in res:
|
||||
logger.info(
|
||||
f"{res['num_paused_requests']} requests are interrupted "
|
||||
f"during updateing weights for server {server_index}: {server_url}"
|
||||
f"during updating weights for server {server_index}: {server_url}"
|
||||
)
|
||||
return
|
||||
logger.warning(
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
def download_from_huggingface(repo_id: str, filename: str, revision: str = "main", repo_type: str = "dataset") -> str:
|
||||
"""
|
||||
Download a file from a HuggingFace Hub repository.
|
||||
"""
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
except ImportError:
|
||||
raise ImportError("Please install huggingface_hub to use this function: pip install huggingface_hub")
|
||||
|
||||
return hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
|
||||
def load_hf_or_local_file(path: str) -> str:
|
||||
"""
|
||||
Load a file from a HuggingFace Hub repository or a local file.
|
||||
hf://<org>/<repo>/<filename>
|
||||
hf://<org>/<repo>@<revision>/<filename>
|
||||
|
||||
e.g,
|
||||
hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl
|
||||
=>
|
||||
repo_type = dataset
|
||||
repo_id = inclusionAI/AReaL-RL-Data
|
||||
filename = data/boba_106k_0319.jsonl
|
||||
revision = main
|
||||
=>
|
||||
/root/.cache/huggingface/hub/models--inclusionAI--AReaL-RL-Data/data/boba_106k_0319.jsonl
|
||||
"""
|
||||
if path.startswith("hf://") or path.startswith("hf-dataset://"):
|
||||
repo_type = "dataset" if path.startswith("hf-dataset://") else "model"
|
||||
hf_path = path.strip().split("://")[1]
|
||||
hf_org, hf_repo, filename = hf_path.split("/", 2)
|
||||
repo_id = f"{hf_org}/{hf_repo}"
|
||||
revision = "main"
|
||||
if "@" in repo_id:
|
||||
repo_id, revision = repo_id.split("@", 1)
|
||||
return download_from_huggingface(repo_id, filename, revision)
|
||||
return path
|
Loading…
Reference in New Issue