mirror of https://github.com/inclusionAI/AReaL
[Feature] Switch dataset path / model path to HF location to ease community usage (#82)
* Update .gitignore and modify dataset paths in scripts for improved file management and compatibility with Hugging Face datasets. Additionally, refactor dataset loading functions to utilize load_hf_or_local_file for better flexibility.
* Remove sglang subproject and update dataset path format in load_hf_or_local_file function for compatibility with Hugging Face datasets.
* Refactor imports in grader.py and parser.py to include sympy for improved functionality.
This commit is contained in:
parent b768e5ce3c
commit de134b4a7a
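For context: after this change, dataset and model paths in the example configs can point either at local files or at Hugging Face Hub locations. A minimal sketch of how the new dataset path from the example script resolves, assuming the load_hf_or_local_file helper added in realhf/utils.py at the end of this diff:

from realhf.utils import load_hf_or_local_file

# Local paths are returned unchanged.
local_path = load_hf_or_local_file("/storage/testing/dataset/boba_106k_0319.jsonl")

# hf-dataset:// paths are downloaded via huggingface_hub and the resolved
# local cache path is returned instead.
cached_path = load_hf_or_local_file(
    "hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl"
)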
@@ -182,3 +182,5 @@ api_key.json
 .vscode/
 wandb/
+outputs/
+sympy/
@@ -2,12 +2,12 @@
 python3 training/main_async_ppo.py \
 n_nodes=1 n_gpus_per_node=8 \
 allocation_mode=sglang.d4p1m1+d2p2m1 \
-cluster.fileroot=/storage/testing/experiments \
+cluster.fileroot=experiments \
 actor.type._class=qwen3 \
-actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \
+actor.path=Qwen/Qwen3-1.7B \
 ref.type._class=qwen3 \
-ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \
-dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \
+ref.path=Qwen/Qwen3-1.7B \
+dataset.path=hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl \
 dataset.train_bs_n_seqs=32 \
 group_size=8 \
 ppo.gen.max_new_tokens=4096 \
@@ -15,4 +15,4 @@ python3 training/main_async_ppo.py \
 actor_train.mb_spec.max_tokens_per_mb=32768 \
 actor_inf.mb_spec.max_tokens_per_mb=32768 \
 max_concurrent_rollouts=16 \
 max_head_offpolicyness=4
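Note that actor.path and ref.path in the script above are now plain Hub repo IDs (Qwen/Qwen3-1.7B) rather than pre-downloaded local directories. As a standalone illustration (not AReaL code) of how such an ID maps to a cached local snapshot, assuming huggingface_hub is installed:

from huggingface_hub import snapshot_download

# Downloads the model repo on first use (or reuses the cache) and returns the
# local snapshot directory, e.g. ~/.cache/huggingface/hub/models--Qwen--Qwen3-1.7B/snapshots/<sha>
local_dir = snapshot_download("Qwen/Qwen3-1.7B")
print(local_dir)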
@@ -14,10 +14,11 @@ from typing import Union
 
 import regex
 from latex2sympy2 import latex2sympy
-from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
 
+from sympy import N, simplify
+
 # from .parser import choice_answer_clean, strip_string
 # from parser import choice_answer_clean
 
@@ -3,10 +3,11 @@ import re
 from typing import Any, Dict, Iterable, List, TypeVar, Union
 
 import regex
-import sympy
 from latex2sympy2 import latex2sympy
 from word2number import w2n
 
+import sympy
+
 # from utils import *
 
@@ -14,6 +14,7 @@ from typing import Any, Dict, List
 import numpy as np
 
 from functioncall.code.verify import code_verify
+from realhf.utils import load_hf_or_local_file
 
 logger = logging.getLogger("function call")
 
@@ -223,6 +224,7 @@ def statics_result(result, query_ids):
 def standard_dataset_eval(
     dataset_path, code_count=0, test_case_batch_size=20, dry_run=False
 ):
+    dataset_path = load_hf_or_local_file(dataset_path)
     id2info = defaultdict(dict)
     generateds, query_ids = [], []
     cnt = 0
@@ -15,10 +15,11 @@ from typing import Union
 
 import regex
 from latex2sympy2 import latex2sympy
-from sympy import N, simplify
 from sympy.parsing.latex import parse_latex
 from sympy.parsing.sympy_parser import parse_expr
 
+from sympy import N, simplify
+
 # from .parser import choice_answer_clean, strip_string
 # from parser import choice_answer_clean
 
@@ -5,10 +5,11 @@ import re
 from typing import Any, Dict, Iterable, List, TypeVar, Union
 
 import regex
-import sympy
 from latex2sympy2 import latex2sympy
 from word2number import w2n
 
+import sympy
+
 # from utils import *
 
@@ -41,6 +41,7 @@ from realhf.api.cli_args import MicroBatchSpec
 from realhf.api.core import config as config_api
 from realhf.base import constants, datapack, logging, seeding
 from realhf.base.cluster import spec as cluster_spec
+from realhf.utils import load_hf_or_local_file
 
 logger = logging.getLogger("api.data")
 
@@ -756,6 +757,7 @@ def load_shuffle_split_dataset(
     dataset_path: str,
     dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None,
 ):
+    dataset_path = load_hf_or_local_file(dataset_path)
     if dataset_path is not None:
         if dataset_path.endswith(".jsonl"):
             with open(dataset_path, "r") as f:
@@ -14,6 +14,7 @@ import torch.utils.data
 
 from realhf.api.core import data_api
 from realhf.base import logging
+from realhf.utils import load_hf_or_local_file
 
 logger = logging.getLogger("Math Code Dataset")
 
@@ -54,6 +55,7 @@ def check_code_metadata_entries(data):
 
 def load_metadata(path):
     assert str(path).endswith(".jsonl"), path
+    path = load_hf_or_local_file(path)
     with open(path, "r") as f:
         data = [json.loads(l) for l in f.readlines()]
 
@@ -185,7 +185,7 @@ class GserverManager(AsyncWorker):
        if "num_paused_requests" in res:
            logger.info(
                f"{res['num_paused_requests']} requests are interrupted "
-               f"during updateing weights for server {server_index}: {server_url}"
+               f"during updating weights for server {server_index}: {server_url}"
            )
            return
        logger.warning(
@@ -0,0 +1,43 @@
+def download_from_huggingface(repo_id: str, filename: str, revision: str = "main", repo_type: str = "dataset") -> str:
+    """
+    Download a file from a HuggingFace Hub repository.
+    """
+    try:
+        from huggingface_hub import hf_hub_download
+    except ImportError:
+        raise ImportError("Please install huggingface_hub to use this function: pip install huggingface_hub")
+
+    return hf_hub_download(
+        repo_id=repo_id,
+        filename=filename,
+        revision=revision,
+        repo_type=repo_type,
+    )
+
+
+def load_hf_or_local_file(path: str) -> str:
+    """
+    Load a file from a HuggingFace Hub repository or a local file.
+        hf://<org>/<repo>/<filename>
+        hf://<org>/<repo>@<revision>/<filename>
+
+    e.g.,
+        hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl
+        =>
+        repo_type = dataset
+        repo_id = inclusionAI/AReaL-RL-Data
+        filename = data/boba_106k_0319.jsonl
+        revision = main
+        =>
+        /root/.cache/huggingface/hub/datasets--inclusionAI--AReaL-RL-Data/snapshots/<revision>/data/boba_106k_0319.jsonl
+    """
+    if path.startswith("hf://") or path.startswith("hf-dataset://"):
+        repo_type = "dataset" if path.startswith("hf-dataset://") else "model"
+        hf_path = path.strip().split("://")[1]
+        hf_org, hf_repo, filename = hf_path.split("/", 2)
+        repo_id = f"{hf_org}/{hf_repo}"
+        revision = "main"
+        if "@" in repo_id:
+            repo_id, revision = repo_id.split("@", 1)
+        return download_from_huggingface(repo_id, filename, revision, repo_type)
+    return path
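A short usage sketch of the two helpers above (the inclusionAI path is the one used in the example script; the revision-pinned path is a hypothetical placeholder):

from realhf.utils import download_from_huggingface, load_hf_or_local_file

# Plain local paths pass through untouched.
assert load_hf_or_local_file("/tmp/data.jsonl") == "/tmp/data.jsonl"

# hf-dataset:// paths are fetched from the Hub and resolved to the cached file.
jsonl_path = load_hf_or_local_file(
    "hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl"
)

# A revision can be pinned with '@' (hypothetical repo and tag shown here).
pinned_path = load_hf_or_local_file("hf-dataset://some-org/some-repo@v1.0/data/file.jsonl")

# download_from_huggingface can also be called directly.
direct_path = download_from_huggingface(
    repo_id="inclusionAI/AReaL-RL-Data",
    filename="data/boba_106k_0319.jsonl",
    revision="main",
    repo_type="dataset",
)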