reformatted

朱晗 2025-07-23 13:58:25 +08:00
parent eff8f09149
commit 6bde86a934
18 changed files with 220 additions and 154 deletions

View File

@@ -690,7 +690,7 @@ class ClusterSpecConfig:
@dataclass
class DatasetConfig:
path: str =field(
path: str = field(
default=MISSING,
metadata={
"help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."

View File

@@ -47,14 +47,17 @@ class LLMResponse:
@property
def output_len(self) -> int:
return len(self.output_tokens)
@dataclass
class VLMRequest(LLMRequest):
image_data: Optional[List[ImageObject|str]] = field(default_factory=list)
image_data: Optional[List[ImageObject | str]] = field(default_factory=list)
@dataclass
class VLMResponse(LLMResponse):
input_images: List[ImageObject|str] = field(default_factory=list)
input_images: List[ImageObject | str] = field(default_factory=list)
@dataclass
class FinetuneSpec:
@@ -142,6 +145,7 @@ class AllocationMode:
raise ValueError(
f"Unknown how to resolve parallelism strategy: {allocation_mode}"
)
@staticmethod
def extract_decoupled_alloc(allocation_mode: str) -> Dict:
pattern = re.compile(

View File

@@ -4,36 +4,39 @@ import transformers
VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
def get_custom_dataset(
path: str,
rank: int,
world_size: int,
training_type: str= "sft",
training_type: str = "sft",
split: Optional[str] = None,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
processor: Optional[transformers.AutoProcessor] = None,
):
):
if "gsm8k" in path and training_type == "sft":
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset
return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size)
elif "gsm8k" in path and training_type == "rl":
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset
return get_gsm8k_rl_dataset(path, split, rank, world_size)
elif "clevr_count_70k" in path and training_type == "sft":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_sft_dataset,
)
return get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size)
elif "clevr_count_70k" in path and training_type == "rl":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_rl_dataset,
)
return get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size)
return get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {training_type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)

View File

@@ -55,7 +55,7 @@ class BaseHFEngine(TrainEngine):
self.own_global_group = False
self._parallelism_group: dist.ProcessGroup
self.weight_update_group_initialized = False
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
@@ -96,8 +96,10 @@ class BaseHFEngine(TrainEngine):
dtype = getattr(torch, self.config.dtype)
if self.is_vision_model:
dtype = torch.bfloat16
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(self.config.path)
dtype = torch.bfloat16
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
self.config.path
)
tik = time.perf_counter()
with torch.device("cuda"):
@@ -132,7 +134,7 @@ class BaseHFEngine(TrainEngine):
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -231,14 +233,12 @@ class BaseHFEngine(TrainEngine):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if self.is_vision_model:
assert "pixel_values" in input_ and "image_grid_thw" in input_, (
"For vision-language models, pixel_values and image_grid_thw must be present in input_"
)
assert (
"pixel_values" in input_ and "image_grid_thw" in input_
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
@@ -303,7 +303,6 @@ class BaseHFEngine(TrainEngine):
loss *= loss_scale
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,

View File

@@ -49,7 +49,6 @@ class FSDPEngine(BaseHFEngine):
self.create_process_group()
self.create_device_model()
# Wrap with FSDP2
# Simple auto wrap policy
@@ -101,7 +100,10 @@ class FSDPEngine(BaseHFEngine):
self.load_optimizer_state(meta.path)
def _save_model_to_hf(
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast], processor: Optional[AutoProcessor]
self,
path: str,
tokenizer: Optional[PreTrainedTokenizerFast],
processor: Optional[AutoProcessor],
):
"""Save model in HuggingFace format."""
if self.model is None:
@@ -147,7 +149,7 @@ class FSDPEngine(BaseHFEngine):
dist.barrier()
torch.cuda.synchronize()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer,self.processor)
self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
# dist.barrier() is called when _save_model_to_hf finishes
if dist.get_rank() == 0:
update_name = names.update_weights_from_disk(
@@ -240,7 +242,6 @@ class FSDPEngine(BaseHFEngine):
loss *= loss_scale
loss.backward()
# NOTE: grad norm clip function is different

View File

@@ -257,8 +257,6 @@ class FSDPPPOActor(FSDPEngine):
return self.actor.ppo_update(*args, **kwargs)
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,

View File

@@ -41,6 +41,7 @@ class FSDPLMEngine(FSDPEngine):
def evaluate_lm(self, data):
return self.lm_engine.evaluate_lm(data)
def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
@@ -49,7 +50,7 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
logprobs = torch.where(loss_mask, logprobs, 0)
loss = -logprobs.sum() / loss_mask.count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(

View File

@@ -228,7 +228,9 @@ class RemoteSGLangEngine(InferenceEngine):
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def agenerate(self, req: LLMRequest|VLMRequest) -> LLMResponse|VLMResponse:
async def agenerate(
self, req: LLMRequest | VLMRequest
) -> LLMResponse | VLMResponse:
"""Async version of generate using aiohttp."""
# Prepare request payload
gconfig = req.gconfig
@@ -318,28 +320,28 @@ class RemoteSGLangEngine(InferenceEngine):
sample_params["max_new_tokens"] -= len(output_tokens)
latency = time.perf_counter() - start_time
if isinstance(req, VLMRequest):
response = VLMResponse(
input_tokens=req.input_ids,
input_images=req.image_data,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
input_tokens=req.input_ids,
input_images=req.image_data,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
else:
response=LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
response = LLMResponse(
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
return response
def update_weights(self, meta):
@@ -526,7 +528,7 @@ class RemoteSGLangEngine(InferenceEngine):
):
try:
data = next(self.data_generator)
except StopIteration:
self.data_generator = iter(dataloader)
data = next(self.data_generator)

View File

@@ -65,8 +65,9 @@ def pad_sequences_to_tensors(
return TensorDict()
skip_keys = {"pixel_values", "image_grid_thw"}
max_length = max(
len(seq) for item in sequence_list
for key, seq in item.items()
len(seq)
for item in sequence_list
for key, seq in item.items()
if key not in skip_keys
)
result = {}
@@ -79,14 +80,18 @@ def pad_sequences_to_tensors(
x = item[key]
if not torch.is_tensor(x):
x = torch.tensor(x)
padded_x=torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded_x = torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
padded.append(padded_x)
result[key] = torch.stack(padded)
attention_mask = [
[1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+ [0] * (max_length - len(next(iter(item[key] for key in item.keys() if key not in skip_keys))))
+ [0]
* (
max_length
- len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
)
for item in sequence_list
]
result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
@@ -139,7 +144,7 @@ def concat_padded_tensors(
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if key == "pixel_values" or key== "image_grid_thw":
if key == "pixel_values" or key == "image_grid_thw":
tensors_to_concat.append(tensor)
continue
if current_length < max_length:
@@ -150,7 +155,7 @@ def concat_padded_tensors(
padding = torch.zeros(
(tensor.shape[0], pad_width), dtype=tensor.dtype
)
else:
# Pad feature tensors with pad_value
padding = torch.full(
@@ -323,7 +328,7 @@ def split_padded_tensor_dict_into_mb_list(
to_split = {}
not_to_split = {}
for key, value in data.items():
if key=="image_grid_thw" or key=="pixel_values":
if key == "image_grid_thw" or key == "pixel_values":
continue
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
@@ -368,7 +373,7 @@ def split_padded_tensor_dict_into_mb_list(
for group_index in group_indices:
group_pixel_values = [pixel_values[i] for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze()for i in group_index]
group_image_grid_thw = [image_grid_thw[i].squeeze() for i in group_index]
# Stack pixel_values for each group (assuming pixel_values is a list of tensors)
pixel_values_split.append(torch.stack(group_pixel_values))

View File

@@ -46,7 +46,7 @@ def fsdp2_clip_grad_norm_(
grads = [p.grad for p in parameters if p.grad is not None]
total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
_clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm

View File

@@ -7,20 +7,17 @@ from typing import List
from PIL.Image import Image as ImageObject
def image2base64(images: List[ImageObject]|ImageObject)-> List[str]|str:
def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
if isinstance(images, ImageObject):
images = [images]
byte_images = []
for image in images:
with BytesIO() as buffer:
image.save(buffer, format="PNG")
buffer.seek(0)
byte_image = base64.b64encode(buffer.read()).decode('utf-8')
buffer.seek(0)
byte_image = base64.b64encode(buffer.read()).decode("utf-8")
byte_images.append(byte_image)
return byte_images

View File

@@ -5,6 +5,7 @@ VALID_VISION_MODELS = [
"qwen2_5_vl",
]
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():

View File

@@ -34,7 +34,7 @@ class VL_RLVRWorkflow(RLVRWorkflow):
return_tensors="pt",
)
input_ids=processed_input["input_ids"].tolist()[0]
input_ids = processed_input["input_ids"].tolist()[0]
n_samples = self.gconfig.n_samples
@@ -62,13 +62,19 @@ class VL_RLVRWorkflow(RLVRWorkflow):
completion_ids=resp.output_tokens,
**data,
)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
pixel_values=processed_input["pixel_values"].clone().detach().unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"].clone().detach().unsqueeze(0),
pixel_values=processed_input["pixel_values"]
.clone()
.detach()
.unsqueeze(0),
image_grid_thw=processed_input["image_grid_thw"]
.clone()
.detach()
.unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),

View File

@@ -24,7 +24,7 @@ from realhf.base import stats_tracker
def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
return match[-1]
return match[-1]
return ""
@@ -56,32 +56,35 @@ def extract_solution(solution_str, method="strict") -> str | None:
break
return final_answer
def clevr_count_70k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
is_thinking = "thinking" in completions.lower()
sol = extract_answer(completions, data_name="") # str number
ans =answer
sol = extract_answer(completions, data_name="") # str number
ans = answer
if sol is None:
return 0
if ans is None:
return 0
if sol.strip() == ans.strip():
print(f"completions: {completions}, answer: {answer}")
if is_thinking:
return 1
return 1
else:
return 1
if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
return 0.05
return 0
def main(args):
os.environ["WANDB_API_KEY"]=""
os.environ["WANDB_API_KEY"] = ""
wandb.init(project="clevr_70k")
config, _ = load_expr_config(args, GRPOConfig)
@@ -90,22 +93,22 @@ def main(args):
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset=get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="rl",
processor=processor
)
valid_dataset=get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="rl",
processor=processor
)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="rl",
processor=processor,
)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="rl",
processor=processor,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
train_dataset,
@@ -142,7 +145,7 @@ def main(args):
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref =FSDPPPOActor(config=config.ref)
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
@@ -200,7 +203,7 @@ def main(args):
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
@@ -212,8 +215,8 @@ def main(args):
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
wandb.log({"actor_reward": stats[0]['grpo_actor/final_reward/avg']})
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
@@ -252,7 +255,14 @@ def main(args):
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"].float().to(actor.device)
wandb.log({"eval_reward": rewards.mean().item(), "epoch": epoch, "step": step, "global_step": global_step})
wandb.log(
{
"eval_reward": rewards.mean().item(),
"epoch": epoch,
"step": step,
"global_step": global_step,
}
)
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
@@ -281,6 +291,6 @@ def main(args):
actor.destroy()
wandb.finish()
if __name__ == "__main__":
main(sys.argv[1:])

View File

@@ -22,25 +22,25 @@ def main_sft():
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
train_dataset=get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
valid_dataset=get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
world_size=world_size,
split="train",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
world_size=world_size,
split="test",
training_type="sft",
tokenizer=tokenizer,
processor=processor,
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
train_dataset,

View File

@@ -9,20 +9,31 @@ from datasets.distributed import split_dataset_by_node
from PIL.Image import Image as ImageObject
def input_text(text:str):
def input_text(text: str):
return {"type": "input_text", "text": text}
def input_image(base64_image: str):
return {"type": "input_image", "image_url": f"data:image/jpeg;base64,{base64_image}"}
def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
return {
"type": "input_image",
"image_url": f"data:image/jpeg;base64,{base64_image}",
}
def build_raw_message(
sample: Dict[str, Any], base64_images: list[str]
) -> list[Dict[str, Any]]:
raw_message = []
problem_parts = [part.strip() for part in sample["problem"].split("<image>") if part.strip()]
problem_parts = [
part.strip() for part in sample["problem"].split("<image>") if part.strip()
]
insert_list = []
for i, part in enumerate(problem_parts):
if i > 0 or sample["problem"].startswith("<image>"):
if i > 0 or sample["problem"].startswith("<image>"):
insert_list.append("image")
part = part.strip()
if part:
part = part.strip()
if part:
insert_list.append("text")
image_index = 0
text_index = 0
@@ -40,17 +51,25 @@ def build_raw_message(sample: Dict[str, Any], base64_images: list[str]) -> list[Dict[str, Any]]:
def encode_image(image_file):
return base64.b64encode(image_file).decode("utf-8")
def convert_image(
image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
image: Union[Dict[str, Any], ImageObject, str],
min_pixels: Optional[int],
max_pixels: Optional[int],
) -> ImageObject:
if max_pixels is not None and (image.width * image.height) > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
width, height = int(image.width * resize_factor), int(
image.height * resize_factor
)
image = image.resize((width, height))
if min_pixels is not None and (image.width * image.height) < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
width, height = int(image.width * resize_factor), int(
image.height * resize_factor
)
image = image.resize((width, height))
if image.mode != "RGB":
@@ -59,32 +78,34 @@ def convert_image(
image.save(output, format="JPEG")
return output.getvalue()
def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
'''
"""
"clevr_count_70k": {
"image_key": "images",
"question_key": "problem",
"answer_key": "answer"
},
'''
"""
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
tokenizer = processor.tokenizer
tokenizer = processor.tokenizer
def process_example(example, idx):
# Add query_id column
images = example["images"]
if 'qwen' in processor.image_processor.image_processor_type.lower():
image_token="<|vision_start|><|image_pad|><|vision_end|>"
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
example["problem"] = example["problem"].replace("<image>", image_token)
processed_images = []
for image in images:
processed_images.append(convert_image(image,113*113,336*336))
processed_images.append(convert_image(image, 113 * 113, 336 * 336))
example["images"] = processed_images
example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
return example
dataset = dataset.map(
@@ -93,8 +114,8 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
)
def _process(example):
text=example["seq"]
processed_input=processor(
text = example["seq"]
processed_input = processor(
text=[text],
images=example["images"],
padding=False,
@@ -103,38 +124,52 @@ def get_clevr_count_70k_sft_dataset(path, split, processor, rank, world_size):
return_attention_mask=False,
)
example["input_ids"] =processed_input["input_ids"].squeeze(0)
example["input_ids"] = processed_input["input_ids"].squeeze(0)
example["pixel_values"] = processed_input["pixel_values"]
example["image_grid_thw"] = processed_input["image_grid_thw"]
answer_token = tokenizer.encode(example["answer"])
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token))+[1]*len(answer_token)
example["loss_mask"]=loss_mask
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
answer_token
)
example["loss_mask"] = loss_mask
return example
dataset = dataset.map(lambda x: _process(x),remove_columns=["images","seq","problem","answer"])
dataset = dataset.map(
lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
)
return dataset
def get_clevr_count_70k_rl_dataset(path, split,processor, rank, world_size):
def get_clevr_count_70k_rl_dataset(path, split, processor, rank, world_size):
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
processed_images = [convert_image(image, 113*113, 336*336) for image in sample["images"]]
if 'qwen' in processor.image_processor.image_processor_type.lower():
image_token="<|vision_start|><|image_pad|><|vision_end|>"
processed_images = [
convert_image(image, 113 * 113, 336 * 336) for image in sample["images"]
]
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
system_prompt = {
"role": "system",
"role": "system",
"content": (
"Solve the following question: count the number of items in the image and provide the final answer in [ ] format, ensuring that only the number is inside the brackets without any additional text or explanations. "
)
),
}
messages =[{"role": "user", "content": sample["problem"].replace("<image>", image_token)}]
messages = [
{
"role": "user",
"content": sample["problem"].replace("<image>", image_token),
}
]
messages.insert(0, system_prompt)
messages=processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
messages = processor.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return {"messages": messages, "images": processed_images}
dataset = dataset.map(process).remove_columns(["problem"])
return dataset
return dataset

View File

@ -5,6 +5,7 @@ from datasets.distributed import split_dataset_by_node
def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
dataset = load_dataset(path=path, name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
seq_token = tokenizer.encode(
sample["question"] + sample["answer"] + tokenizer.eos_token
@ -16,12 +17,14 @@ def get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size):
dataset = dataset.map(process).remove_columns(["question", "answer"])
return dataset
def get_gsm8k_rl_dataset(path,split, rank, world_size):
def get_gsm8k_rl_dataset(path, split, rank, world_size):
dataset = load_dataset(path=path, name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages}
dataset = dataset.map(process).remove_columns(["question"])
return dataset
return dataset

View File

@@ -68,6 +68,7 @@ def load_hf_tokenizer(
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
@lru_cache(maxsize=8)
def load_hf_processor_and_tokenizer(
model_name_or_path: str,