mirror of https://github.com/inclusionAI/AReaL
0731_4

parent 2e0af5dd87
commit c3c986ae76
@@ -4,26 +4,50 @@ from typing import Any, Dict, Optional, Union
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
+from PIL import Image
 from PIL.Image import Image as ImageObject


 def convert_image(
-    image: Union[Dict[str, Any], ImageObject, str],
-    max_pixels: Optional[int],
-) -> ImageObject:
-    if max_pixels is not None and (image.width * image.height) > max_pixels:
-        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
-        width, height = int(image.width * resize_factor), int(
-            image.height * resize_factor
-        )
-        image = image.resize((width, height))
-
-    if image.mode != "RGB":
-        image = image.convert("RGB")
+    image: Union[Dict[str, Any], Image.Image, str],
+    target_width: int,
+    target_height: int,
+) -> Image.Image:
+    """
+    Convert the image by padding it to the target width and height.
+    """
+    # Get the current size of the image
+    width, height = image.size
+
+    # Calculate padding for width and height
+    pad_width = max(target_width - width, 0)
+    pad_height = max(target_height - height, 0)
+
+    # Calculate padding for left, right, top, bottom
+    left = pad_width // 2
+    top = pad_height // 2
+
+    # Create a new image with target size and a white background
+    new_image = Image.new("RGB", (target_width, target_height), (255, 255, 255))
+
+    # Paste the original image into the center of the new image
+    new_image.paste(image, (left, top))

     with BytesIO() as output:
-        image.save(output, format="JPEG")
+        new_image.save(output, format="JPEG")
         return output.getvalue()


+def get_max_image_size(dataset):
+    """
+    Traverse the dataset to find the maximum width and height across all images.
+    """
+    max_width, max_height = 0, 0
+    for example in dataset:
+        for image in example["images"]:
+            width, height = image.size
+            max_width = max(max_width, width)
+            max_height = max(max_height, height)
+    return max_width, max_height
+
+
 def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
     """
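The hunk above swaps the old max_pixels downscaling for white-background padding to a fixed target size. A minimal sketch of the new logic, assuming Pillow and a hypothetical 200x100 input padded to 336x336; note that, as in the diff, the body returns JPEG bytes from output.getvalue() rather than the Image.Image its annotation promises:

# Minimal sketch (not part of the commit): pad a hypothetical 200x100
# image onto a white 336x336 canvas, mirroring the new convert_image.
from io import BytesIO
from PIL import Image

img = Image.new("RGB", (200, 100), (0, 0, 0))
target_w, target_h = 336, 336

left = max(target_w - img.width, 0) // 2    # 68
top = max(target_h - img.height, 0) // 2    # 118

canvas = Image.new("RGB", (target_w, target_h), (255, 255, 255))
canvas.paste(img, (left, top))

with BytesIO() as output:
    canvas.save(output, format="JPEG")
    jpeg_bytes = output.getvalue()  # bytes, despite the Image.Image annotation

print(canvas.size, len(jpeg_bytes) > 0)  # (336, 336) True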
@@ -36,6 +60,8 @@ def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

+    max_width, max_height = get_max_image_size(dataset)
+
     tokenizer = processor.tokenizer

     def process_example(example, idx):
@@ -50,7 +76,7 @@ def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
         )
         processed_images = []
         for image in images:
-            processed_images.append(convert_image(image, 336 * 336))
+            processed_images.append(convert_image(image, max_width, max_height))
         example["images"] = processed_images
         example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token

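Both dataset builders now call get_max_image_size after split_dataset_by_node, so each rank scans only its own shard and different ranks can end up padding to different sizes. A hypothetical way to reconcile this, assuming a torch.distributed process group is already initialized (none of this is in the commit):

# Hypothetical follow-up, not in the commit: all-reduce the per-shard
# maxima so every rank pads to the same (width, height).
import torch
import torch.distributed as dist

def global_max_image_size(dataset):
    max_width, max_height = get_max_image_size(dataset)  # per-shard max
    if dist.is_available() and dist.is_initialized():
        sizes = torch.tensor([max_width, max_height], dtype=torch.int64)
        dist.all_reduce(sizes, op=dist.ReduceOp.MAX)
        max_width, max_height = int(sizes[0]), int(sizes[1])
    return max_width, max_height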
@@ -92,9 +118,12 @@ def get_geometry3k_rl_dataset(path, split, processor, rank, world_size):
     dataset = load_dataset(path=path, split=split)
     dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

+
+    max_width, max_height = get_max_image_size(dataset)
+
     def process(sample):
         processed_images = [
-            convert_image(image, 1024 * 1024) for image in sample["images"]
+            convert_image(image, max_width, max_height) for image in sample["images"]
         ]
         if "qwen" in processor.image_processor.image_processor_type.lower():
             image_token = "<|vision_start|><|image_pad|><|vision_end|>"
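The hunk cuts off at the Qwen-style placeholder; how image_token is spliced into the prompt is not shown. A purely hypothetical illustration of the usual pattern, assuming sample["problem"] is the prompt field (it appears in the SFT hunk above):

# Hypothetical, not shown in the hunk: one placeholder per image,
# prepended to the problem text for Qwen-style processors.
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
prompt = image_token * len(sample["images"]) + sample["problem"]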
@@ -1,9 +1,9 @@
 import re


 def extract_answer(pred_str, data_name, use_last_number=True):
-    match = re.findall(r"\[([0-9\.]+)\]", pred_str)
-    if match:
-        return match[-1]
+    matches = re.findall(r"\[([^\]]+)\]", pred_str)
+    if matches:
+        return matches[-1]

     return ""
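The widened pattern trades numeric-only matching for any bracketed span, so non-numeric answers such as LaTeX fractions are now extracted. A quick comparison of the two patterns (the example string is invented):

# Comparison sketch: the old pattern only matched digits and dots,
# the new one captures any non-"]" content between brackets.
import re

pred = r"The answer is [\frac{\sqrt{3}}{2}]."
old = re.findall(r"\[([0-9\.]+)\]", pred)  # [] -- no numeric-only match
new = re.findall(r"\[([^\]]+)\]", pred)    # ['\\frac{\\sqrt{3}}{2}']
print(old, new)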
@@ -19,7 +19,7 @@ def geometry3k_reward_fn(
         return 0
     if ans is None:
         return 0
-    print(f"sol: {sol}, ans: {ans}")
+    # print(f"sol: {sol}, ans: {ans}")
     from realhf.impl.dataset.math_parser import math_equal
     if math_equal(sol, ans):
         print(f"completions: {completions}, answer: {answer}")
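With the per-sample debug print commented out, the function reduces to a binary reward: 1 when math_equal accepts the extracted answer, 0 otherwise. A hypothetical sketch of that flow, reusing only names visible in the hunks (the helper below is not from the repo):

# Hypothetical sketch: combine extract_answer and math_equal into a
# 0/1 reward, following the call sites visible in the diff.
from realhf.impl.dataset.math_parser import math_equal

def sketch_reward(completion: str, solution: str) -> int:
    ans = extract_answer(completion, "geometry3k")  # model's bracketed answer
    sol = extract_answer(solution, "geometry3k")    # reference answer
    if not ans or not sol:
        return 0
    return 1 if math_equal(sol, ans) else 0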