This commit is contained in:
lichangye.lcy 2025-07-31 18:36:31 +08:00
parent 2e0af5dd87
commit c3c986ae76
2 changed files with 48 additions and 19 deletions

View File

@ -4,26 +4,50 @@ from typing import Any, Dict, Optional, Union
from datasets import load_dataset from datasets import load_dataset
from datasets.distributed import split_dataset_by_node from datasets.distributed import split_dataset_by_node
from PIL import Image
from PIL.Image import Image as ImageObject from PIL.Image import Image as ImageObject
def convert_image(
    image: Union[Dict[str, Any], Image.Image, str],
    target_width: int,
    target_height: int,
) -> bytes:
    """
    Pad an image onto a white canvas of the target size and encode it as JPEG.

    The image is centred on the canvas. When the image is larger than the
    target along an axis, the offset on that axis is clamped to 0, so the
    image is anchored at the left/top edge and whatever extends past the
    canvas is not part of the output.

    Args:
        image: The image to pad. The body reads ``image.size`` and
            ``image.mode``, so only PIL-style image objects are handled here
            even though the annotation also lists ``dict`` and ``str``.
        target_width: Width of the output canvas in pixels.
        target_height: Height of the output canvas in pixels.

    Returns:
        The padded image encoded as JPEG bytes.
    """
    width, height = image.size

    # Clamp the padding at zero so an oversized image never gets a negative
    # offset; the remaining padding is split evenly to centre the image.
    left = max(target_width - width, 0) // 2
    top = max(target_height - height, 0) // 2

    # White RGB canvas: JPEG has no alpha channel, so the result must be RGB.
    canvas = Image.new("RGB", (target_width, target_height), (255, 255, 255))

    # For images with an alpha band, use it as the paste mask so transparent
    # regions show the white canvas instead of the hidden colour underneath.
    mask = image if image.mode in ("RGBA", "LA") else None
    canvas.paste(image, (left, top), mask)

    with BytesIO() as output:
        canvas.save(output, format="JPEG")
        return output.getvalue()
def get_max_image_size(dataset):
    """
    Scan every image of every example and report the largest dimensions.

    Each example is indexed with ``"images"`` and each image must expose a
    ``size`` pair of ``(width, height)``. The returned width and height are
    tracked independently, so they may come from different images. An empty
    dataset yields ``(0, 0)``.
    """
    widest = tallest = 0
    all_sizes = (img.size for record in dataset for img in record["images"])
    for w, h in all_sizes:
        if w > widest:
            widest = w
        if h > tallest:
            tallest = h
    return widest, tallest
def get_geometry3k_sft_dataset(path, split, processor, rank, world_size): def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
""" """
@ -36,6 +60,8 @@ def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
dataset = load_dataset(path=path, split=split) dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
max_width, max_height = get_max_image_size(dataset)
tokenizer = processor.tokenizer tokenizer = processor.tokenizer
def process_example(example, idx): def process_example(example, idx):
@ -50,7 +76,7 @@ def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
) )
processed_images = [] processed_images = []
for image in images: for image in images:
processed_images.append(convert_image(image, 336 * 336)) processed_images.append(convert_image(image, max_width, max_height))
example["images"] = processed_images example["images"] = processed_images
example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token
@ -92,9 +118,12 @@ def get_geometry3k_rl_dataset(path, split, processor, rank, world_size):
dataset = load_dataset(path=path, split=split) dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
max_width, max_height = get_max_image_size(dataset)
def process(sample): def process(sample):
processed_images = [ processed_images = [
convert_image(image, 1024 * 1024) for image in sample["images"] convert_image(image, max_width, max_height) for image in sample["images"]
] ]
if "qwen" in processor.image_processor.image_processor_type.lower(): if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>" image_token = "<|vision_start|><|image_pad|><|vision_end|>"

View File

@ -1,9 +1,9 @@
import re import re
def extract_answer(pred_str, data_name, use_last_number=True):
    """
    Return the contents of the last ``[...]`` group found in ``pred_str``.

    A group is one or more characters other than ``]`` enclosed in square
    brackets. Returns an empty string when no such group exists.
    ``data_name`` and ``use_last_number`` are accepted but not used here.
    """
    bracketed = re.findall(r"\[([^\]]+)\]", pred_str)
    return bracketed[-1] if bracketed else ""
@ -19,7 +19,7 @@ def geometry3k_reward_fn(
return 0 return 0
if ans is None: if ans is None:
return 0 return 0
print(f"sol: {sol}, ans: {ans}") # print(f"sol: {sol}, ans: {ans}")
from realhf.impl.dataset.math_parser import math_equal from realhf.impl.dataset.math_parser import math_equal
if math_equal(sol, ans): if math_equal(sol, ans):
print(f"completions: {completions}, answer: {answer}") print(f"completions: {completions}, answer: {answer}")