From c3c986ae76e18d14252b034c470a4117132c6200 Mon Sep 17 00:00:00 2001 From: "lichangye.lcy" Date: Thu, 31 Jul 2025 18:36:31 +0800 Subject: [PATCH] 0731_4 --- examples/arealite/dataset/geometry3k.py | 59 ++++++++++++++++++------- examples/arealite/reward/geometry3k.py | 8 ++-- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/examples/arealite/dataset/geometry3k.py b/examples/arealite/dataset/geometry3k.py index a8b49b7..7414d73 100644 --- a/examples/arealite/dataset/geometry3k.py +++ b/examples/arealite/dataset/geometry3k.py @@ -4,26 +4,50 @@ from typing import Any, Dict, Optional, Union from datasets import load_dataset from datasets.distributed import split_dataset_by_node +from PIL import Image from PIL.Image import Image as ImageObject def convert_image( - image: Union[Dict[str, Any], ImageObject, str], - max_pixels: Optional[int], -) -> ImageObject: - if max_pixels is not None and (image.width * image.height) > max_pixels: - resize_factor = math.sqrt(max_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int( - image.height * resize_factor - ) - image = image.resize((width, height)) - - if image.mode != "RGB": - image = image.convert("RGB") + image: Union[Dict[str, Any], Image.Image, str], + target_width: int, + target_height: int, +) -> Image.Image: + """ + Convert the image by padding it to the target width and height. + """ + # Get the current size of the image + width, height = image.size + + # Calculate padding for width and height + pad_width = max(target_width - width, 0) + pad_height = max(target_height - height, 0) + + # Calculate padding for left, right, top, bottom + left = pad_width // 2 + top = pad_height // 2 + + # Create a new image with target size and a white background + new_image = Image.new("RGB", (target_width, target_height), (255, 255, 255)) + + # Paste the original image into the center of the new image + new_image.paste(image, (left, top)) + with BytesIO() as output: - image.save(output, format="JPEG") + new_image.save(output, format="JPEG") return output.getvalue() +def get_max_image_size(dataset): + """ + Traverse the dataset to find the maximum width and height across all images. + """ + max_width, max_height = 0, 0 + for example in dataset: + for image in example["images"]: + width, height = image.size + max_width = max(max_width, width) + max_height = max(max_height, height) + return max_width, max_height def get_geometry3k_sft_dataset(path, split, processor, rank, world_size): """ @@ -36,6 +60,8 @@ def get_geometry3k_sft_dataset(path, split, processor, rank, world_size): dataset = load_dataset(path=path, split=split) dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + max_width, max_height = get_max_image_size(dataset) + tokenizer = processor.tokenizer def process_example(example, idx): @@ -50,7 +76,7 @@ def get_geometry3k_sft_dataset(path, split, processor, rank, world_size): ) processed_images = [] for image in images: - processed_images.append(convert_image(image, 336 * 336)) + processed_images.append(convert_image(image, max_width, max_height)) example["images"] = processed_images example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token @@ -92,9 +118,12 @@ def get_geometry3k_rl_dataset(path, split, processor, rank, world_size): dataset = load_dataset(path=path, split=split) dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + + max_width, max_height = get_max_image_size(dataset) + def process(sample): processed_images = [ - convert_image(image, 1024 * 1024) for image in sample["images"] + convert_image(image, max_width, max_height) for image in sample["images"] ] if "qwen" in processor.image_processor.image_processor_type.lower(): image_token = "<|vision_start|><|image_pad|><|vision_end|>" diff --git a/examples/arealite/reward/geometry3k.py b/examples/arealite/reward/geometry3k.py index bfc0265..fd90669 100644 --- a/examples/arealite/reward/geometry3k.py +++ b/examples/arealite/reward/geometry3k.py @@ -1,9 +1,9 @@ import re def extract_answer(pred_str, data_name, use_last_number=True): - match = re.findall(r"\[([0-9\.]+)\]", pred_str) - if match: - return match[-1] + matches = re.findall(r"\[([^\]]+)\]", pred_str) + if matches: + return matches[-1] return "" @@ -19,7 +19,7 @@ def geometry3k_reward_fn( return 0 if ans is None: return 0 - print(f"sol: {sol}, ans: {ans}") + # print(f"sol: {sol}, ans: {ans}") from realhf.impl.dataset.math_parser import math_equal if math_equal(sol, ans): print(f"completions: {completions}, answer: {answer}")